from torchvision.ops import box_convert

def remove_bd_from_gt(gt_boxes, gt_labels, gt_poison_masks, gt_target_ids):
    """Drop or relabel poisoned ground-truth annotations.

    An annotation survives if its poison mask is 0, or if its target id is
    positive. Every other annotation is discarded. Among the survivors, those
    with poison mask 1 and a positive target id get their label replaced by
    that target id.

    The four tensors are filtered with the same boolean mask, so they are
    presumably aligned along their first dimension (TODO confirm with the
    dataset). Boolean-mask indexing returns copies, so the caller's tensors
    are not modified.

    Returns:
        Tuple ``(boxes, labels, poison_masks, target_ids)`` after filtering and
        relabelling.
    """
    is_clean = gt_poison_masks == 0
    has_target = gt_target_ids > 0
    survivors = is_clean | has_target

    boxes, labels, masks, targets = (
        tensor[survivors]
        for tensor in (gt_boxes, gt_labels, gt_poison_masks, gt_target_ids)
    )

    # Poisoned annotations that were kept take on their target id as label.
    to_relabel = (masks == 1) & (targets > 0)
    labels[to_relabel] = targets[to_relabel]

    return boxes, labels, masks, targets


def format_output(outputs, target, img, current_box_format, required_box_format, remove_bd=False):
    """Turn a batch of detector outputs and ground truth into plain-Python records.

    For every image in the batch, the predictions (``outputs[i]``) and the
    ground truth (``target[i]``) are moved to the CPU and checked for label 0.
    The ground truth is optionally passed through ``remove_bd_from_gt``. The
    boxes of both are then converted from ``current_box_format`` to
    ``required_box_format`` with ``torchvision.ops.box_convert``.

    Args:
        outputs: Sequence of per-image prediction dicts with tensor values under
            ``"boxes"``, ``"labels"`` and ``"scores"``.
        target: Sequence of per-image ground-truth dicts with tensor values under
            ``"bbox"``, ``"category_id"``, ``"poison_mask"`` and ``"target_id"``.
        img: Sequence of per-image tensors. Only ``shape[-1]`` (reported as
            width) and ``shape[-2]`` (reported as height) are read.
        current_box_format: ``box_convert`` format string the boxes are in.
        required_box_format: ``box_convert`` format string to convert boxes to.
        remove_bd: If True, filter and relabel the ground truth with
            ``remove_bd_from_gt`` before conversion.

    Returns:
        A list with one dict per image. Each dict holds the image width and
        height, plus the predicted boxes/labels/scores and the ground-truth
        boxes/labels/poison masks/target ids, all as Python lists.

    Raises:
        ValueError: If ``outputs``, ``target`` and ``img`` differ in length, or
            if any ground-truth or predicted label equals 0.
    """
    # Validate lengths up front. Otherwise zip() would silently drop unmatched
    # images, and a short sequence would only surface as an IndexError.
    if not (len(outputs) == len(target) == len(img)):
        raise ValueError(
            "outputs, target and img must have the same length, "
            f"got {len(outputs)}, {len(target)} and {len(img)}."
        )

    sub_results = []
    for i, (output, gt, image) in enumerate(zip(outputs, target, img)):
        pred_boxes = output["boxes"].cpu()
        pred_labels = output["labels"].cpu()
        pred_scores = output["scores"].cpu()

        gt_boxes = gt["bbox"].cpu()
        gt_labels = gt["category_id"].cpu()
        gt_poison_masks = gt["poison_mask"].cpu()
        gt_target_ids = gt["target_id"].cpu()

        # Label 0 is rejected in both ground truth and predictions. Dump the
        # offending image's data before failing, to make debugging easier.
        if (gt_labels == 0).any():
            print(f"Warning: Found gt_labels with value 0 in image {i}. This may indicate an issue with the dataset.")
            print(f"GT Labels: {gt_labels.tolist()}")
            print(f"GT Boxes: {gt_boxes.tolist()}")
            print(f"GT Poison Masks: {gt_poison_masks.tolist()}")
            print(f"GT Target IDs: {gt_target_ids.tolist()}")

            raise ValueError("Ground truth labels contain 0, which is not expected in this context.")

        if (pred_labels == 0).any():
            print(f"Warning: Found pred_labels with value 0 in image {i}. This may indicate an issue with the model or data.")
            print(f"Predicted labels: {pred_labels.tolist()}")
            print(f"Pred Scores: {pred_scores.tolist()}")

            raise ValueError("Predicted labels contain 0, which is not expected in this context.")

        if remove_bd:
            gt_boxes, gt_labels, gt_poison_masks, gt_target_ids = remove_bd_from_gt(
                gt_boxes, gt_labels, gt_poison_masks, gt_target_ids
            )

        # Convert prediction and ground-truth boxes into the caller-requested
        # format (whatever required_box_format names, not necessarily COCO xywh).
        pred_boxes = box_convert(pred_boxes, current_box_format, required_box_format)
        gt_boxes = box_convert(gt_boxes, current_box_format, required_box_format)

        sub_results.append({
            "img_width": image.shape[-1],
            "img_height": image.shape[-2],
            "pred_boxes": pred_boxes.tolist(),
            "pred_labels": pred_labels.tolist(),
            "pred_scores": pred_scores.tolist(),
            "gt_boxes": gt_boxes.tolist(),
            "gt_labels": gt_labels.tolist(),
            "gt_poison_masks": gt_poison_masks.tolist(),
            "gt_target_ids": gt_target_ids.tolist(),
        })

    return sub_results