# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Dict, List
from monai.metrics import HausdorffDistanceMetric

import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
import logging

from training.trainer import CORE_LOSS_KEY

from training.utils.distributed import get_world_size, is_dist_avail_and_initialized
from training.utils.boundary_utils import get_boundary_from_masks, save_boundary_visualization, save_boundary_visualization_simple
from training.utils.spatial_utils import RelativePositionLoss


def boundary_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_objects: float,
    bce_weight: float = 1,
    dice_weight: float = 0.5
) -> torch.Tensor:
    """
    Return the weighted binary cross-entropy loss for boundary prediction.

    Only the BCE term is computed at the moment. The Dice term is not part of
    the result, so ``num_objects`` and ``dice_weight`` are accepted for
    interface compatibility but are not read.

    Args:
        inputs (torch.Tensor): Raw boundary logits, e.g. of shape (B, C, H, W).
        targets (torch.Tensor): Ground-truth boundary masks with the same
                                shape as ``inputs``.
        num_objects (float): Currently unused.
        bce_weight (float): Multiplier applied to the BCE term.
        dice_weight (float): Currently unused.

    Returns:
        torch.Tensor: Scalar tensor, ``bce_weight`` times the BCE-with-logits
        loss averaged over every element of ``inputs``.
    """
    # The logits variant is used so the sigmoid and the log are fused, which is
    # numerically more stable than applying them separately.
    mean_bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="mean")
    return bce_weight * mean_bce


def _generate_boundary_from_mask_gpu(mask: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    """
    Build a boundary map from a binary mask as a morphological gradient.

    Dilation and erosion are both emulated with an all-ones ``F.conv2d``
    kernel, so the whole computation stays on the mask's device.
    Internal helper.

    Args:
        mask (torch.Tensor): Target mask of shape (B, 1, H, W) with values 0 or 1.
        kernel_size (int): Side length of the square morphological kernel.

    Returns:
        torch.Tensor: Float tensor holding dilation minus erosion. It has the
        mask's shape when ``kernel_size`` is odd.
    """
    as_float = mask.float()
    ones_kernel = torch.ones(
        1, 1, kernel_size, kernel_size, dtype=torch.float32, device=mask.device
    )
    pad = (kernel_size - 1) // 2

    # Dilation: a pixel is set as soon as its window contains any foreground.
    foreground_hits = F.conv2d(as_float, ones_kernel, padding=pad)
    dilation = (foreground_hits > 0).float()

    # Erosion is NOT(DILATE(NOT(mask))): a pixel survives only when its window
    # contains no background at all, i.e. the background count is zero.
    background_hits = F.conv2d(1 - as_float, ones_kernel, padding=pad)
    erosion = ((1 - background_hits) > 0).float()

    # Morphological gradient = dilation - erosion.
    return dilation - erosion


def _convert_3channel_to_binary_mask(
    image_tensor: torch.Tensor,
    threshold: float = 0.5
) -> torch.Tensor:
    """
    Collapse a three-channel image into a single-channel binary mask.

    The channels are combined with the weights 0.299 / 0.587 / 0.114 (the
    usual luma weights, presumably for RGB channel order) and the resulting
    grayscale image is thresholded.

    Args:
        image_tensor (torch.Tensor): Image of shape (B, 3, H, W); pixel values
                                     are assumed to lie in [0, 1].
        threshold (float): Grayscale pixels strictly greater than this value
                           become 1 (foreground), all others 0.

    Returns:
        torch.Tensor: Float binary mask of shape (B, 1, H, W).

    Raises:
        ValueError: If ``image_tensor`` is not 4-D with exactly 3 channels.
    """
    has_expected_layout = image_tensor.dim() == 4 and image_tensor.shape[1] == 3
    if not has_expected_layout:
        raise ValueError(
            f"Expected a 4D tensor with 3 channels, but got shape {image_tensor.shape}"
        )

    # Shaped (1, 3, 1, 1) so the weights broadcast over batch and spatial dims.
    channel_weights = torch.tensor(
        [0.299, 0.587, 0.114], device=image_tensor.device
    ).view(1, 3, 1, 1)

    # Weighted sum over the channel axis; keepdim leaves a singleton channel.
    weighted = image_tensor * channel_weights
    grayscale = weighted.sum(dim=1, keepdim=True)

    return (grayscale > threshold).float()


def slice_boundary_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_objects: int,
    boundary_kernel_size: int = 3,
    binarization_threshold: float = 0.5
) -> torch.Tensor:
    """
    Compute the slice-level boundary loss.

    The ground-truth boundary is derived on the fly from the segmentation
    target and compared with the predicted boundary logits using
    BCE-with-logits.

    Args:
        inputs (torch.Tensor): Boundary prediction logits of shape (N, 1, H, W).
        targets (torch.Tensor): Ground truth, either a segmentation mask of
                                shape (N, 1, H, W) or a three-channel image of
                                shape (N, 3, H, W) that is binarized first.
        num_objects (int): Number of objects in the batch, used for normalization.
        boundary_kernel_size (int): Size of the morphological kernel used to
                                    generate the ground-truth boundary.
        binarization_threshold (float): Threshold used when ``targets`` has
                                        three channels.

    Returns:
        torch.Tensor: Scalar loss: the per-sample mean BCE, summed over samples
        and divided by ``num_objects``.

    Raises:
        ValueError: If ``inputs`` is not single-channel, or ``targets`` has a
            channel count other than 1 or 3.
    """
    # Predictions must be single-channel.
    if inputs.shape[1] != 1:
        raise ValueError(
            f"This function expects 'inputs' (predictions) to have a single channel, "
            f"but got {inputs.shape[1]} channels."
        )

    target_channels = targets.shape[1]
    if target_channels not in (1, 3):
        raise ValueError(
            f"Targets must have 1 or 3 channels, but got {targets.shape[1]}."
        )

    if target_channels == 3:
        # A three-channel target is an image; reduce it to a binary mask.
        # The conversion must not take part in gradient computation.
        with torch.no_grad():
            segmentation = _convert_3channel_to_binary_mask(
                targets.float(), threshold=binarization_threshold
            )
    else:
        # Already a single-channel mask; use it as is.
        segmentation = targets.float()

    # Step 1: turn the segmentation mask into a ground-truth boundary
    # (no gradients are needed for the target side).
    with torch.no_grad():
        gt_boundary = _generate_boundary_from_mask_gpu(segmentation, boundary_kernel_size)

    # Step 2: element-wise BCE between the predicted logits and that boundary.
    elementwise = F.binary_cross_entropy_with_logits(inputs, gt_boundary, reduction="none")
    per_sample = elementwise.mean(dim=(1, 2, 3))

    return per_sample.sum() / num_objects


def dice_loss(inputs, targets, num_objects, loss_on_multimask=False):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions (logits) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
                In the single-mask case an already flattened [N, K] tensor is
                accepted as well.
        num_objects: Number of objects in the batch
        loss_on_multimask: True if multimask prediction is enabled
    Returns:
        Dice loss tensor: shape [N, M] (one value per predicted mask, divided by
        num_objects) when loss_on_multimask is True, otherwise a scalar (sum
        over examples divided by num_objects).
    """
    inputs = inputs.sigmoid()
    if loss_on_multimask:
        # inputs and targets are [N, M, H, W] where M corresponds to multiple predicted masks
        assert inputs.dim() == 4 and targets.dim() == 4
        # flatten spatial dimension while keeping multimask channel dimension
        inputs = inputs.flatten(2)
        targets = targets.flatten(2)
        numerator = 2 * (inputs * targets).sum(-1)
    else:
        inputs = inputs.flatten(1)
        # Flatten targets the same way as inputs. Without this, targets given
        # in the documented "same shape as inputs" layout (e.g. [N, 1, H, W])
        # cannot be multiplied element-wise with the flattened predictions.
        # Targets that are already [N, K] are left untouched.
        if targets.dim() > 2:
            targets = targets.flatten(1)
        numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # +1 smoothing keeps the ratio defined when both masks are empty
    loss = 1 - (numerator + 1) / (denominator + 1)
    if loss_on_multimask:
        return loss / num_objects
    return loss.sum() / num_objects


def sigmoid_focal_loss(
    inputs,
    targets,
    num_objects,
    alpha: float = 0.25,
    gamma: float = 2,
    loss_on_multimask=False,
):
    """
    Focal loss as used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of logits, one prediction per element.
        targets: A float tensor with the same shape as inputs holding the
                 binary label of each element
                (0 for the negative class and 1 for the positive class).
        num_objects: Number of objects in the batch; the result is divided by it.
        alpha: Weighting factor in range (0,1) to balance positive vs negative
                examples. Default = 0.25; any negative value disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
        loss_on_multimask: True if multimask prediction is enabled
    Returns:
        focal loss tensor: [N, M] (spatial mean per predicted mask) when
        loss_on_multimask is True, otherwise a scalar (mean over dim 1, summed
        over the remaining dims).
    """
    probabilities = inputs.sigmoid()
    elementwise_bce = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none"
    )

    # Probability the model assigns to the correct class of each element.
    true_class_prob = probabilities * targets + (1 - probabilities) * (1 - targets)
    focal = elementwise_bce * ((1 - true_class_prob) ** gamma)

    if alpha >= 0:
        # alpha for positives, (1 - alpha) for negatives
        balance = alpha * targets + (1 - alpha) * (1 - targets)
        focal = balance * focal

    if not loss_on_multimask:
        return focal.mean(1).sum() / num_objects

    # focal is [N, M, H, W] where M corresponds to multiple predicted masks
    assert focal.dim() == 4
    return focal.flatten(2).mean(-1) / num_objects  # average over spatial dims


def iou_loss(
    inputs, targets, pred_ious, num_objects, loss_on_multimask=False, use_l1_loss=False
):
    """
    Regression loss between predicted IoU scores and the IoUs actually achieved.
    Args:
        inputs: A 4-D float tensor of mask logits; an element counts as
                foreground when it is greater than 0.
        targets: A 4-D float tensor with the same shape as inputs holding the
                 binary label of each element
                (0 for the negative class and 1 for the positive class).
        pred_ious: A float tensor containing the predicted IoUs scores per mask
        num_objects: Number of objects in the batch
        loss_on_multimask: True if multimask prediction is enabled
        use_l1_loss: Whether L1 loss is used instead of MSE loss
    Returns:
        IoU loss tensor: one value per mask divided by num_objects when
        loss_on_multimask is True, otherwise their sum divided by num_objects.
    """
    assert inputs.dim() == 4 and targets.dim() == 4

    predicted_fg = inputs.flatten(2) > 0
    target_fg = targets.flatten(2) > 0
    overlap = (predicted_fg & target_fg).sum(dim=-1).float()
    combined = (predicted_fg | target_fg).sum(dim=-1).float()
    # Clamp so that an empty union yields an IoU of 0 instead of 0 / 0.
    measured_ious = overlap / combined.clamp(min=1.0)

    regression = F.l1_loss if use_l1_loss else F.mse_loss
    per_mask = regression(pred_ious, measured_ious, reduction="none")

    if loss_on_multimask:
        return per_mask / num_objects
    return per_mask.sum() / num_objects


def spatial_loss(pos_logits, pos_target):
    """
    Cross-entropy loss for position classification.

    The first two dimensions of ``pos_logits`` are merged so that every
    (dim 0, dim 1) entry is treated as one classification sample over the
    last dimension.

    Args:
        pos_logits: Float tensor of shape (B, W, C) holding class logits.
        pos_target: Tensor with B * W elements (e.g. shape (B, W)) holding the
            class index of each entry; it is cast to long.

    Returns:
        Scalar tensor: the cross-entropy averaged over all B * W entries.

    Raises:
        ValueError: If ``pos_logits`` is None. Raising (instead of terminating
            the interpreter with a success exit status) lets the caller see
            and handle the failure.
    """
    if pos_logits is None:
        logging.warning("pos_logits is None")
        raise ValueError("spatial_loss requires pos_logits, but got None")

    num_rows, num_windows, num_classes = pos_logits.shape
    num_samples = num_rows * num_windows

    # reshape (rather than view) also accepts non-contiguous tensors.
    return F.cross_entropy(
        pos_logits.reshape(num_samples, num_classes),  # (B*W, C)
        pos_target.reshape(num_samples).long(),  # (B*W,)
    )


def consistency_loss(current_token, prev_token):
    """
    Cosine-distance loss that pulls the current token towards the previous one.

    Args:
        current_token: Tensor of token features, or None.
        prev_token: Tensor broadcastable with ``current_token``, or None. It is
            detached, so gradients only flow into ``current_token``.

    Returns:
        The int 0 when either argument is None; otherwise a scalar tensor with
        the mean of ``1 - cosine_similarity`` taken along the last dimension.
    """
    # Identity comparison: `tensor == None` only works through the Tensor
    # equality fallback and is not the intended check.
    if current_token is None or prev_token is None:
        return 0

    # dim=-1 makes the similarity run over the feature dimension.
    similarity = F.cosine_similarity(current_token, prev_token.detach(), dim=-1)

    return (1.0 - similarity).mean()


class MultiStepMultiMasksAndIous(nn.Module):
    """Mask (focal), dice, IoU-regression and object-score losses over multi-step predictions."""

    def __init__(
        self,
        weight_dict,
        focal_alpha=0.25,
        focal_gamma=2,
        supervise_all_iou=False,
        iou_use_l1_loss=False,
        pred_obj_scores=False,
        focal_gamma_obj_score=0.0,
        focal_alpha_obj_score=-1,
    ):
        """
        Set up the multi-step multi-mask and IoU loss.
        Args:
            weight_dict: dict with the weights of the individual losses; must
                contain "loss_mask", "loss_dice" and "loss_iou". A missing
                "loss_class" entry is added with weight 0.0 (the dict passed
                in is modified in place).
            focal_alpha: alpha for sigmoid focal loss
            focal_gamma: gamma for sigmoid focal loss
            supervise_all_iou: if True, back-prop iou losses for all predicted masks
            iou_use_l1_loss: use L1 loss instead of MSE loss for iou
            pred_obj_scores: if True, compute loss for object scores
            focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
            focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
        """
        super().__init__()

        self.weight_dict = weight_dict
        for required_key in ("loss_mask", "loss_dice", "loss_iou"):
            assert required_key in self.weight_dict
        if "loss_class" not in self.weight_dict:
            self.weight_dict["loss_class"] = 0.0

        # mask focal-loss settings
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        # object-score focal-loss settings
        self.focal_alpha_obj_score = focal_alpha_obj_score
        self.focal_gamma_obj_score = focal_gamma_obj_score
        self.pred_obj_scores = pred_obj_scores
        # IoU-head settings
        self.supervise_all_iou = supervise_all_iou
        self.iou_use_l1_loss = iou_use_l1_loss

    def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
        """
        Sum the losses of every (outputs, targets) pair in the batch.

        Args:
            outs_batch: One output dict per entry of ``targets_batch``.
            targets_batch: Target tensor iterated along dim 0 in lockstep with
                ``outs_batch``; its dim 1 is taken as the number of objects.

        Returns:
            A defaultdict mapping each loss name to its sum over the batch.
        """
        assert len(outs_batch) == len(targets_batch)

        # Number of objects is fixed within a batch
        object_count = torch.tensor(
            targets_batch.shape[1], dtype=torch.float, device=targets_batch.device
        )
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(object_count)
        # Average over the workers and never normalize by less than one object.
        num_objects = (object_count / get_world_size()).clamp(min=1).item()

        totals = defaultdict(int)
        for entry_outputs, entry_targets in zip(outs_batch, targets_batch):
            entry_losses = self._forward(entry_outputs, entry_targets, num_objects)
            for name, value in entry_losses.items():
                totals[name] += value

        return totals

    def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
        """
        Compute the mask-related losses (focal and dice) together with the MAE
        or MSE loss between predicted IoUs and actual IoUs for one batch entry.

        "multistep_pred_multimasks_high_res" is a list with one multimask
        tensor per prediction step, each of shape [N, M, H, W], where M could
        be 1 or larger (one or multiple predicted masks from a click).

        Focal and dice losses are back-propagated only through the prediction
        channel with the lowest focal+dice loss against the ground truth.
        If `supervise_all_iou` is True, iou losses are back-propagated for all
        predicted masks.
        """
        gt_masks = targets.unsqueeze(1).float()
        assert gt_masks.dim() == 4  # [N, 1, H, W]

        masks_per_step = outputs["multistep_pred_multimasks_high_res"]
        ious_per_step = outputs["multistep_pred_ious"]
        obj_logits_per_step = outputs["multistep_object_score_logits"]
        assert len(masks_per_step) == len(ious_per_step)
        assert len(obj_logits_per_step) == len(ious_per_step)

        # accumulate the loss over prediction steps
        losses = dict.fromkeys(("loss_mask", "loss_dice", "loss_iou", "loss_class"), 0)
        for step_masks, step_ious, step_obj_logits in zip(
            masks_per_step, ious_per_step, obj_logits_per_step
        ):
            self._update_losses(
                losses, step_masks, gt_masks, step_ious, num_objects, step_obj_logits
            )
        losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
        return losses

    def _update_losses(
        self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
    ):
        """Add the losses of one prediction step to ``losses`` in place."""
        target_masks = target_masks.expand_as(src_masks)

        # focal, dice and iou loss on all output masks of this prediction step
        focal_per_mask = sigmoid_focal_loss(
            src_masks,
            target_masks,
            num_objects,
            alpha=self.focal_alpha,
            gamma=self.focal_gamma,
            loss_on_multimask=True,
        )
        dice_per_mask = dice_loss(
            src_masks, target_masks, num_objects, loss_on_multimask=True
        )

        if self.pred_obj_scores:
            # An object counts as present when its target mask has any
            # positive pixel; result has shape [N, 1].
            object_present = (
                (target_masks[:, 0] > 0).flatten(1).any(dim=-1).unsqueeze(-1).float()
            )
            loss_class = sigmoid_focal_loss(
                object_score_logits,
                object_present,
                num_objects,
                alpha=self.focal_alpha_obj_score,
                gamma=self.focal_gamma_obj_score,
            )
        else:
            # Without object scores every object is treated as present and the
            # classification loss is a constant zero.
            loss_class = torch.tensor(
                0.0, dtype=focal_per_mask.dtype, device=focal_per_mask.device
            )
            object_present = torch.ones(
                focal_per_mask.shape[0],
                1,
                dtype=focal_per_mask.dtype,
                device=focal_per_mask.device,
            )

        iou_per_mask = iou_loss(
            src_masks,
            target_masks,
            ious,
            num_objects,
            loss_on_multimask=True,
            use_l1_loss=self.iou_use_l1_loss,
        )
        assert focal_per_mask.dim() == 2
        assert dice_per_mask.dim() == 2
        assert iou_per_mask.dim() == 2

        if focal_per_mask.size(1) > 1:
            # pick, per object, the mask with the smallest weighted focal + dice
            # loss for back propagation
            ranking = (
                focal_per_mask * self.weight_dict["loss_mask"]
                + dice_per_mask * self.weight_dict["loss_dice"]
            )
            winners = torch.argmin(ranking, dim=-1)
            rows = torch.arange(ranking.size(0), device=ranking.device)
            loss_mask = focal_per_mask[rows, winners].unsqueeze(1)
            loss_dice = dice_per_mask[rows, winners].unsqueeze(1)
            # the iou prediction loss is taken only at the index with the
            # minimum loss for each mask (to be consistent w/ SAM), unless all
            # masks are supervised
            if self.supervise_all_iou:
                loss_iou = iou_per_mask.mean(dim=-1).unsqueeze(1)
            else:
                loss_iou = iou_per_mask[rows, winners].unsqueeze(1)
        else:
            loss_mask, loss_dice, loss_iou = focal_per_mask, dice_per_mask, iou_per_mask

        # backprop focal, dice and iou loss only if obj present
        loss_mask = loss_mask * object_present
        loss_dice = loss_dice * object_present
        loss_iou = loss_iou * object_present

        # sum over batch dimension (note that the losses are already divided by num_objects)
        losses["loss_mask"] += loss_mask.sum()
        losses["loss_dice"] += loss_dice.sum()
        losses["loss_iou"] += loss_iou.sum()
        losses["loss_class"] += loss_class

    def reduce_loss(self, losses):
        """
        Return the weighted sum of ``losses`` using ``self.weight_dict``.

        Raises:
            ValueError: If a key of the weight dict is missing from ``losses``.
        """
        reduced_loss = 0.0
        for loss_key, weight in self.weight_dict.items():
            if loss_key not in losses:
                raise ValueError(f"{type(self)} doesn't compute {loss_key}")
            if weight == 0:
                # zero-weighted terms do not contribute
                continue
            reduced_loss += losses[loss_key] * weight

        return reduced_loss


class MultiStepMultiMasksAndIous4Val(nn.Module):
    def __init__(
        self,
        weight_dict,
        focal_alpha=0.25,
        focal_gamma=2,
        supervise_all_iou=False,
        iou_use_l1_loss=False,
        pred_obj_scores=False,
        focal_gamma_obj_score=0.0,
        focal_alpha_obj_score=-1,
    ):
        """
        This class computes the multi-step multi-mask and IoU losses and, in
        addition, Dice / IoU / HD95 metrics for validation.
        Args:
            weight_dict: dict containing weights for focal, dice, iou losses;
                a missing "loss_class" entry is added with weight 0.0
                (the dict passed in is modified in place)
            focal_alpha: alpha for sigmoid focal loss
            focal_gamma: gamma for sigmoid focal loss
            supervise_all_iou: if True, back-prop iou losses for all predicted masks
            iou_use_l1_loss: use L1 loss instead of MSE loss for iou
            pred_obj_scores: if True, compute loss for object scores
            focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
            focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
        """

        super().__init__()
        self.weight_dict = weight_dict
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        assert "loss_mask" in self.weight_dict
        assert "loss_dice" in self.weight_dict
        assert "loss_iou" in self.weight_dict
        if "loss_class" not in self.weight_dict:
            self.weight_dict["loss_class"] = 0.0

        self.focal_alpha_obj_score = focal_alpha_obj_score
        self.focal_gamma_obj_score = focal_gamma_obj_score
        self.supervise_all_iou = supervise_all_iou
        self.iou_use_l1_loss = iou_use_l1_loss
        self.pred_obj_scores = pred_obj_scores

        # Running metric sums: forward() resets them, _forward() adds to them.
        # new_hd95 is initialized here as well so that _forward() can be used
        # before the first forward() call without an AttributeError.
        self.new_dice = 0
        self.new_iou = 0
        self.new_hd95 = 0

    def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
        """
        Sum the losses over the batch and average the per-entry metrics.

        Args:
            outs_batch: One output dict per entry of ``targets_batch``.
            targets_batch: Target tensor iterated along dim 0 in lockstep with
                ``outs_batch``; its dim 1 is taken as the number of objects.

        Returns:
            Tuple ``(losses, metric_dict)``: ``losses`` maps each key produced
            by ``_forward`` to its sum over the batch; ``metric_dict`` holds
            the 'dice', 'iou' and 'hd95' values averaged over the batch entries.
        """
        assert len(outs_batch) == len(targets_batch)
        num_objects = torch.tensor(
            (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
        )  # Number of objects is fixed within a batch
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_objects)
        num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()

        losses = defaultdict(int)
        self.new_dice = 0
        self.new_iou = 0
        self.new_hd95 = 0

        for outs, targets in zip(outs_batch, targets_batch):
            cur_losses = self._forward(outs, targets, num_objects)
            for k, v in cur_losses.items():
                losses[k] += v

        outs_len = len(outs_batch)
        logging.info(f"outs_len:{outs_len}")
        # Guard against an empty batch, which would otherwise divide by zero.
        denom = max(outs_len, 1)
        metric_dict = {
            'dice': self.new_dice / denom,
            'iou': self.new_iou / denom,
            'hd95': self.new_hd95 / denom,
        }
        logging.info(f"metric_dict:{metric_dict}")
        return losses, metric_dict

    def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects):
        """
        Compute the losses related to the masks: the focal loss and the dice loss.
        and also the MAE or MSE loss between predicted IoUs and actual IoUs.

        Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
        of shape [N, M, H, W], where M could be 1 or larger, corresponding to
        one or multiple predicted masks from a click.

        We back-propagate focal, dice losses only on the prediction channel
        with the lowest focal+dice loss between predicted mask and ground-truth.
        If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks.

        As a side effect, the Dice / IoU / HD95 metrics of the last prediction
        step are added to ``self.new_dice`` / ``self.new_iou`` / ``self.new_hd95``.
        """

        target_masks = targets.unsqueeze(1).float()
        assert target_masks.dim() == 4  # [N, 1, H, W]
        src_masks_list = outputs["multistep_pred_multimasks_high_res"]
        ious_list = outputs["multistep_pred_ious"]
        object_score_logits_list = outputs["multistep_object_score_logits"]

        assert len(src_masks_list) == len(ious_list)
        assert len(object_score_logits_list) == len(ious_list)

        # accumulate the loss over prediction steps
        losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0}
        for src_masks, ious, object_score_logits in zip(
            src_masks_list, ious_list, object_score_logits_list
        ):
            self._update_losses(
                losses, src_masks, target_masks, ious, num_objects, object_score_logits
            )
        losses[CORE_LOSS_KEY] = self.reduce_loss(losses)

        # Metrics are computed on the masks of the last prediction step only.
        final_pred_masks = src_masks_list[-1]
        # torch.no_grad() keeps the metric computation out of the autograd graph.
        with torch.no_grad():
            metrics = self.calculate_metrics(final_pred_masks, target_masks)

        logging.info(f"new metrics:{metrics}")
        self.new_dice += metrics['dice']
        self.new_iou += metrics['iou']
        self.new_hd95 += metrics['hd95']

        return losses

    def calculate_metrics(self, preds: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5, epsilon: float = 1e-6) -> Dict[str, torch.Tensor]:
        """
        Compute Dice, IoU and HD95 metrics; handles SAM-style multi-mask output
        by scoring, for every sample, the candidate mask with the highest IoU.

        Args:
            preds (torch.Tensor): Predicted masks of shape [N, M, H, W], usually logits.
            targets (torch.Tensor): Ground-truth masks of shape [N, 1, H, W].
            threshold (float): Probability threshold used to binarize the predictions.
            epsilon (float): Small constant that prevents division by zero.

        Returns:
            Dict[str, torch.Tensor]: The batch-mean "dice" and "iou" scores and
            the aggregated "hd95" value.
        """
        # 1. Turn the predictions into probabilities and binarize them.
        # [N, M, H, W] -> [N, M, H, W]
        binary_preds = (torch.sigmoid(preds) > threshold).float()

        # 2. Expand the targets to the prediction shape for vectorized computation.
        # [N, 1, H, W] -> [N, M, H, W]
        targets_expanded = targets.expand_as(binary_preds)

        # 3. Intersection and union of every candidate mask.
        # .sum(dim=[2, 3]) adds up H and W, giving a pixel count per mask.
        intersection = (binary_preds * targets_expanded).sum(dim=[2, 3])  # Shape: [N, M]
        pred_sum = binary_preds.sum(dim=[2, 3])  # Shape: [N, M]
        target_sum = targets_expanded.sum(dim=[2, 3])  # Shape: [N, M]
        union = pred_sum + target_sum - intersection  # Shape: [N, M]

        # 4. IoU of every candidate, then the best candidate per sample.
        iou_scores = (intersection + epsilon) / (union + epsilon)  # Shape: [N, M]
        best_ious, best_indices = torch.max(iou_scores, dim=1)  # Shape: [N], [N]

        # 5. Pick the intersection and pixel sums of the best candidate for the Dice score.
        best_intersection = torch.gather(intersection, 1, best_indices.unsqueeze(1)).squeeze(1)  # Shape: [N]
        best_pred_sum = torch.gather(pred_sum, 1, best_indices.unsqueeze(1)).squeeze(1)  # Shape: [N]
        # Every column of target_sum is identical because the target was expanded.
        best_target_sum = target_sum[:, 0]  # Shape: [N]

        # 6. Dice score of the best candidate.
        dice_scores = (2. * best_intersection + epsilon) / (best_pred_sum + best_target_sum + epsilon)  # Shape: [N]

        # --- HD95 computation ---
        # 7. Select the best predicted mask.
        # best_indices [N] -> [N, 1, H, W] so it can be used as a gather index
        best_indices_expanded = best_indices.view(-1, 1, 1, 1).expand(-1, 1, *binary_preds.shape[2:])
        # Pick the best mask out of [N, M, H, W] -> [N, 1, H, W]
        best_preds_mask = torch.gather(binary_preds, 1, best_indices_expanded)

        # 8. Set up the HD95 metric (95th-percentile Hausdorff distance).
        # NOTE(review): include_background=True is presumably needed because the
        # single channel here is the foreground itself — confirm against MONAI docs.
        hd_metric = HausdorffDistanceMetric(include_background=True, percentile=95, reduction="mean")

        # 9. Compute HD95 on (N, 1, H, W) inputs.
        # NOTE(review): behavior for empty prediction or target masks is not
        # handled here — verify what MONAI returns in that case.
        # If GPU memory is tight, the inputs could be moved to the CPU first.
        hd_metric(y_pred=best_preds_mask, y=targets)
        hd95_score = hd_metric.aggregate()  # aggregated result
        # --- end of HD95 computation ---

        # 10. Batch means.
        return {
            "dice": dice_scores.mean(),
            "iou": best_ious.mean(),
            "hd95": hd95_score  # as returned by the MONAI metric
        }

    def _update_losses(
        self, losses, src_masks, target_masks, ious, num_objects, object_score_logits
    ):
        """
        Add the losses of one prediction step to ``losses`` in place and
        accumulate "metric_dice" / "metric_iou" sums together with
        "metric_count" for the objects that are present.
        """
        target_masks = target_masks.expand_as(src_masks)
        # get focal, dice and iou loss on all output masks in a prediction step
        loss_multimask = sigmoid_focal_loss(
            src_masks,
            target_masks,
            num_objects,
            alpha=self.focal_alpha,
            gamma=self.focal_gamma,
            loss_on_multimask=True,
        )
        loss_multidice = dice_loss(
            src_masks, target_masks, num_objects, loss_on_multimask=True
        )
        if not self.pred_obj_scores:
            loss_class = torch.tensor(
                0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
            )
            target_obj = torch.ones(
                loss_multimask.shape[0],
                1,
                dtype=loss_multimask.dtype,
                device=loss_multimask.device,
            )
        else:
            target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
                ..., None
            ].float()
            loss_class = sigmoid_focal_loss(
                object_score_logits,
                target_obj,
                num_objects,
                alpha=self.focal_alpha_obj_score,
                gamma=self.focal_gamma_obj_score,
            )

        loss_multiiou = iou_loss(
            src_masks,
            target_masks,
            ious,
            num_objects,
            loss_on_multimask=True,
            use_l1_loss=self.iou_use_l1_loss,
        )
        assert loss_multimask.dim() == 2
        assert loss_multidice.dim() == 2
        assert loss_multiiou.dim() == 2
        if loss_multimask.size(1) > 1:
            # take the mask indices with the smallest focal + dice loss for back propagation
            loss_combo = (
                loss_multimask * self.weight_dict["loss_mask"]
                + loss_multidice * self.weight_dict["loss_dice"]
            )
            best_loss_inds = torch.argmin(loss_combo, dim=-1)
            batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
            loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
            loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
            # calculate the iou prediction and slot losses only in the index
            # with the minimum loss for each mask (to be consistent w/ SAM)
            if self.supervise_all_iou:
                loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
            else:
                loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)

            best_pred_logits = src_masks[batch_inds, best_loss_inds]  # Shape: [N, H, W]
        else:
            loss_mask = loss_multimask
            loss_dice = loss_multidice
            loss_iou = loss_multiiou

            best_pred_logits = src_masks.squeeze(1)  # Shape: [N, H, W]

        # backprop focal, dice and iou loss only if obj present
        loss_mask = loss_mask * target_obj
        loss_dice = loss_dice * target_obj
        loss_iou = loss_iou * target_obj

        # sum over batch dimension (note that the losses are already divided by num_objects)
        losses["loss_mask"] += loss_mask.sum()
        losses["loss_dice"] += loss_dice.sum()
        losses["loss_iou"] += loss_iou.sum()
        losses["loss_class"] += loss_class

        # ==================== metric accumulation ====================
        with torch.no_grad():
            gt_masks = target_masks[:, 0, :, :]
            best_pred_binary = (best_pred_logits.sigmoid() > 0.5).float()

            pred_flat = best_pred_binary.flatten(1)
            gt_flat = gt_masks.flatten(1)

            epsilon = 1e-6

            intersection = (pred_flat * gt_flat).sum(1)
            pred_sum = pred_flat.sum(1)
            gt_sum = gt_flat.sum(1)

            # Per-sample Dice and IoU scores
            dice_score = (2. * intersection + epsilon) / (pred_sum + gt_sum + epsilon)
            iou_score = (intersection + epsilon) / (pred_sum + gt_sum - intersection + epsilon)

            # Number of objects that are present in this step
            num_valid_objects = target_obj.sum()

            # Accumulate only when at least one object is present, so a later
            # division by the count cannot produce NaN.
            if num_valid_objects > 0:
                valid = target_obj.squeeze(-1)

                # Sums (not means) are stored, together with the count.
                losses["metric_dice"] = losses.get("metric_dice", 0.0) + (dice_score * valid).sum()
                losses["metric_iou"] = losses.get("metric_iou", 0.0) + (iou_score * valid).sum()
                losses["metric_count"] = losses.get("metric_count", 0.0) + num_valid_objects

    def reduce_loss(self, losses):
        """
        Return the weighted sum of ``losses`` using ``self.weight_dict``;
        zero-weighted entries are skipped.

        Raises:
            ValueError: If a key of the weight dict is missing from ``losses``.
        """
        reduced_loss = 0.0
        for loss_key, weight in self.weight_dict.items():
            if loss_key not in losses:
                raise ValueError(f"{type(self)} doesn't compute {loss_key}")
            if weight != 0:
                reduced_loss += losses[loss_key] * weight

        return reduced_loss


class MultiStepMultiMasksAndIousFull(nn.Module):
    """Multi-step multi-mask / IoU loss with an optional one-shot boundary term.

    For every element of ``outs_batch`` the mask (focal), dice, IoU,
    object-score (``loss_class``) and spatial losses are accumulated over the
    prediction steps. When the ``loss_boundary`` weight is positive, a boundary
    loss is additionally computed once from ``outs_batch[0]["boundary_logits"]``
    and injected into every per-frame loss dict.
    """

    def __init__(
        self,
        weight_dict,
        focal_alpha=0.25,
        focal_gamma=2,
        supervise_all_iou=False,
        iou_use_l1_loss=False,
        pred_obj_scores=False,
        focal_gamma_obj_score=0.0,
        focal_alpha_obj_score=-1,
        boundary_bce_weight=1,
        boundary_dice_weight=1,
    ):
        """
        This class computes the multi-step multi-mask and IoU losses.
        Args:
            weight_dict: dict containing weights for focal, dice, iou losses.
                Must contain "loss_mask", "loss_dice" and "loss_iou"; the keys
                "loss_class", "loss_boundary" and "loss_spatial" are added
                in place with weight 0.0 when missing.
            focal_alpha: alpha for sigmoid focal loss
            focal_gamma: gamma for sigmoid focal loss
            supervise_all_iou: if True, back-prop iou losses for all predicted masks
            iou_use_l1_loss: use L1 loss instead of MSE loss for iou
            pred_obj_scores: if True, compute loss for object scores
            focal_gamma_obj_score: gamma for sigmoid focal loss on object scores
            focal_alpha_obj_score: alpha for sigmoid focal loss on object scores
            boundary_bce_weight: weight of the BCE component passed to `boundary_loss`
            boundary_dice_weight: weight of the Dice component passed to `boundary_loss`
        """

        super().__init__()
        # NOTE: the dict is stored by reference; the optional keys filled in
        # below are therefore written into the caller's object as well.
        self.weight_dict = weight_dict
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        assert "loss_mask" in self.weight_dict
        assert "loss_dice" in self.weight_dict
        assert "loss_iou" in self.weight_dict
        if "loss_class" not in self.weight_dict:
            self.weight_dict["loss_class"] = 0.0

        self.focal_alpha_obj_score = focal_alpha_obj_score
        self.focal_gamma_obj_score = focal_gamma_obj_score
        self.supervise_all_iou = supervise_all_iou
        self.iou_use_l1_loss = iou_use_l1_loss
        self.pred_obj_scores = pred_obj_scores

        # Component weights forwarded to `boundary_loss`.
        self.boundary_bce_weight = boundary_bce_weight
        self.boundary_dice_weight = boundary_dice_weight

        # RelativePositionLoss
        self.relative_pos_loss = RelativePositionLoss(8)   # hardcode here

        # `reduce_loss` requires every weighted key to exist, so make sure the
        # optional terms always have a (zero) weight.
        if "loss_boundary" not in self.weight_dict:
            self.weight_dict["loss_boundary"] = 0.0
        if "loss_spatial" not in self.weight_dict:
            self.weight_dict["loss_spatial"] = 0.0

        # Counter kept for optional boundary visualisation dumps; not used by
        # the loss computation itself.
        self.viz_counter = 0

    def forward(self, outs_batch: List[Dict], targets_batch: torch.Tensor):
        """Accumulate all losses over the elements of ``outs_batch``.

        Args:
            outs_batch: list of per-frame output dicts; must have the same
                length as ``targets_batch``. ``outs_batch[0]`` must hold
                "boundary_logits" when the "loss_boundary" weight is positive.
            targets_batch: ground-truth masks; iterated along dim 0 in lockstep
                with ``outs_batch``. ``shape[1]`` is used as the object count.

        Returns:
            A ``defaultdict`` with every key returned by `_forward`, summed
            over all frames.
        """
        assert len(outs_batch) == len(targets_batch)
        num_objects = torch.tensor(
            (targets_batch.shape[1]), device=targets_batch.device, dtype=torch.float
        )  # Number of objects is fixed within a batch
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_objects)
        num_objects = torch.clamp(num_objects / get_world_size(), min=1).item()

        losses = defaultdict(int)

        # The boundary head is supervised once per batch, outside the per-frame
        # loop. It defaults to 0 so that the per-frame call below also works
        # when the boundary loss is disabled (weight <= 0); previously `loss_b`
        # was left undefined in that case.
        loss_b = 0
        if self.weight_dict["loss_boundary"] > 0:
            loss_b = self._compute_boundary_loss(outs_batch, targets_batch, num_objects)
            # NOTE(review): `loss_b` is added here and once more per frame via
            # `_forward`, so the reported "loss_boundary" ends up being
            # (len(outs_batch) + 1) * loss_b. Kept as-is to preserve training
            # behaviour; confirm whether this is intended.
            losses["loss_boundary"] += loss_b

        # calculate loss per frame
        for outs, targets in zip(outs_batch, targets_batch):
            cur_losses = self._forward(outs, targets, num_objects, loss_boundary=loss_b)
            for k, v in cur_losses.items():
                losses[k] += v

        return losses

    def _compute_boundary_loss(self, outs_batch, targets_batch, num_objects):
        """Compute the one-shot boundary loss from ``outs_batch[0]["boundary_logits"]``."""
        # Per the original author's note the logits are (B_flat, 1, H_pred, W_pred)
        # -- TODO confirm against the model's boundary head.
        pred_boundaries = outs_batch[0]["boundary_logits"]

        # Merge the first two target dims and add a channel dim so the targets
        # line up with the predictions: (D0 * D1, 1, H, W).
        flat_targets = targets_batch.flatten(0, 1).unsqueeze(1).float()

        num_frames_pred = pred_boundaries.shape[0]
        num_frames_gt = flat_targets.shape[0]
        if num_frames_pred != num_frames_gt:
            # Counts can disagree (e.g. 64 vs 72); supervise only the common prefix.
            min_frames = min(num_frames_pred, num_frames_gt)
            pred_boundaries = pred_boundaries[:min_frames]
            flat_targets = flat_targets[:min_frames]
            # Debug-level only: this may fire on every step. No rank query is
            # made here because `torch.distributed.get_rank()` raises when no
            # process group has been initialised.
            logging.debug(
                "Mismatch in frame count for boundary loss. Preds: %d, GTs: %d. "
                "Using %d frames for calculation.",
                num_frames_pred,
                num_frames_gt,
                min_frames,
            )

        gt_boundaries = get_boundary_from_masks(flat_targets)

        # Upsample predictions if needed
        if pred_boundaries.shape[-2:] != gt_boundaries.shape[-2:]:
            pred_boundaries = F.interpolate(
                pred_boundaries,
                size=gt_boundaries.shape[-2:],
                mode="bilinear", align_corners=False
            )

        return boundary_loss(
            pred_boundaries, gt_boundaries, num_objects,
            bce_weight=self.boundary_bce_weight,
            dice_weight=self.boundary_dice_weight
        )

    def _forward(self, outputs: Dict, targets: torch.Tensor, num_objects, loss_boundary=0):
        """
        Compute the losses related to the masks: the focal loss and the dice loss.
        and also the MAE or MSE loss between predicted IoUs and actual IoUs.

        Here "multistep_pred_multimasks_high_res" is a list of multimasks (tensors
        of shape [N, M, H, W], where M could be 1 or larger, corresponding to
        one or multiple predicted masks from a click.

        We back-propagate focal, dice losses only on the prediction channel
        with the lowest focal+dice loss between predicted mask and ground-truth.
        If `supervise_all_iou` is True, we backpropagate ious losses for all predicted masks.

        `loss_boundary` is a precomputed boundary loss that is copied into the
        returned dict unchanged. The returned dict also holds the weighted total
        under `CORE_LOSS_KEY`.
        """

        target_masks = targets.unsqueeze(1).float()
        assert target_masks.dim() == 4  # [N, 1, H, W]
        src_masks_list = outputs["multistep_pred_multimasks_high_res"]
        ious_list = outputs["multistep_pred_ious"]
        object_score_logits_list = outputs["multistep_object_score_logits"]
        pos_logits = outputs["pos_logits"]
        pos_target = outputs["pos_target"]

        assert len(src_masks_list) == len(ious_list)
        assert len(object_score_logits_list) == len(ious_list)

        # accumulate the loss over prediction steps
        losses = {"loss_mask": 0, "loss_dice": 0, "loss_iou": 0, "loss_class": 0,
                  "loss_boundary": loss_boundary, "loss_spatial": 0}
        for src_masks, ious, object_score_logits in zip(
            src_masks_list, ious_list, object_score_logits_list
        ):
            self._update_losses(
                losses, src_masks, target_masks, ious, num_objects, object_score_logits, pos_logits, pos_target
            )

        losses[CORE_LOSS_KEY] = self.reduce_loss(losses)
        return losses

    def _update_losses(
        self, losses, src_masks, target_masks, ious, num_objects, object_score_logits, pos_logits, pos_target
    ):
        """Add the losses of one prediction step to ``losses`` in place.

        Updates "loss_mask", "loss_dice", "loss_iou", "loss_spatial" and
        "loss_class". When the module is in eval mode it also accumulates the
        "metric_dice", "metric_iou" and "metric_count" entries.
        """
        target_masks = target_masks.expand_as(src_masks)
        # get focal, dice and iou loss on all output masks in a prediction step
        loss_multimask = sigmoid_focal_loss(
            src_masks,
            target_masks,
            num_objects,
            alpha=self.focal_alpha,
            gamma=self.focal_gamma,
            loss_on_multimask=True,
        )
        loss_multidice = dice_loss(
            src_masks, target_masks, num_objects, loss_on_multimask=True
        )
        if not self.pred_obj_scores:
            loss_class = torch.tensor(
                0.0, dtype=loss_multimask.dtype, device=loss_multimask.device
            )
            # Without object-score prediction every entry counts as present.
            target_obj = torch.ones(
                loss_multimask.shape[0],
                1,
                dtype=loss_multimask.dtype,
                device=loss_multimask.device,
            )
        else:
            # An object is present when its ground-truth mask has any positive pixel.
            target_obj = torch.any((target_masks[:, 0] > 0).flatten(1), dim=-1)[
                ..., None
            ].float()
            loss_class = sigmoid_focal_loss(
                object_score_logits,
                target_obj,
                num_objects,
                alpha=self.focal_alpha_obj_score,
                gamma=self.focal_gamma_obj_score,
            )

        loss_multiiou = iou_loss(
            src_masks,
            target_masks,
            ious,
            num_objects,
            loss_on_multimask=True,
            use_l1_loss=self.iou_use_l1_loss,
        )

        loss_spatial = spatial_loss(
            pos_logits,
            pos_target
        )

        assert loss_multimask.dim() == 2
        assert loss_multidice.dim() == 2
        assert loss_multiiou.dim() == 2
        if loss_multimask.size(1) > 1:
            # take the mask indices with the smallest focal + dice loss for back propagation
            loss_combo = (
                loss_multimask * self.weight_dict["loss_mask"]
                + loss_multidice * self.weight_dict["loss_dice"]
            )
            best_loss_inds = torch.argmin(loss_combo, dim=-1)
            batch_inds = torch.arange(loss_combo.size(0), device=loss_combo.device)
            loss_mask = loss_multimask[batch_inds, best_loss_inds].unsqueeze(1)
            loss_dice = loss_multidice[batch_inds, best_loss_inds].unsqueeze(1)
            # calculate the iou prediction and slot losses only in the index
            # with the minimum loss for each mask (to be consistent w/ SAM)
            if self.supervise_all_iou:
                loss_iou = loss_multiiou.mean(dim=-1).unsqueeze(1)
            else:
                loss_iou = loss_multiiou[batch_inds, best_loss_inds].unsqueeze(1)

            # Selected regardless of `supervise_all_iou`: the eval metrics below
            # always need the best mask (it used to be assigned only in the
            # `else` branch above, leaving it unbound otherwise).
            best_pred_logits = src_masks[batch_inds, best_loss_inds]  # Shape: [N, H, W]
        else:
            loss_mask = loss_multimask
            loss_dice = loss_multidice
            loss_iou = loss_multiiou

            best_pred_logits = src_masks.squeeze(1)  # Shape: [N, H, W]

        # backprop focal, dice and iou loss only if obj present
        loss_mask = loss_mask * target_obj
        loss_dice = loss_dice * target_obj
        loss_iou = loss_iou * target_obj
        loss_spatial = loss_spatial * target_obj

        # sum over batch dimension (note that the losses are already divided by num_objects)
        losses["loss_mask"] += loss_mask.sum()
        losses["loss_dice"] += loss_dice.sum()
        losses["loss_iou"] += loss_iou.sum()
        losses["loss_spatial"] += loss_spatial.sum()
        losses["loss_class"] += loss_class

        # Segmentation metrics are only tracked outside of training.
        if not self.training:
            self._accumulate_metrics(losses, best_pred_logits, target_masks, target_obj)

    @staticmethod
    def _accumulate_metrics(losses, best_pred_logits, target_masks, target_obj):
        """Add summed Dice / IoU scores and the valid-object count to ``losses``."""
        with torch.no_grad():
            gt_flat = target_masks[:, 0, :, :].flatten(1)
            pred_flat = (best_pred_logits.sigmoid() > 0.5).float().flatten(1)

            epsilon = 1e-6

            intersection = (pred_flat * gt_flat).sum(1)
            pred_sum = pred_flat.sum(1)
            gt_sum = gt_flat.sum(1)

            # Per-sample Dice and IoU scores.
            dice_score = (2. * intersection + epsilon) / (pred_sum + gt_sum + epsilon)
            iou_score = (intersection + epsilon) / (pred_sum + gt_sum - intersection + epsilon)

            # Number of entries that actually contain an object.
            num_valid_objects = target_obj.sum()

            # Skip the update when nothing is valid so that no empty
            # contribution is recorded.
            if num_valid_objects > 0:
                valid = target_obj.squeeze(-1)
                # Sums and counts are accumulated separately; the consumer is
                # expected to divide them to obtain averages.
                losses["metric_dice"] = losses.get("metric_dice", 0.0) + (dice_score * valid).sum()
                losses["metric_iou"] = losses.get("metric_iou", 0.0) + (iou_score * valid).sum()
                losses["metric_count"] = losses.get("metric_count", 0.0) + num_valid_objects

    def reduce_loss(self, losses):
        """Return the weighted sum of ``losses`` according to ``self.weight_dict``.

        Raises:
            ValueError: if a key of ``self.weight_dict`` is missing from
                ``losses`` (checked even when its weight is zero).
        """
        reduced_loss = 0.0
        for loss_key, weight in self.weight_dict.items():
            if loss_key not in losses:
                raise ValueError(f"{type(self)} doesn't compute {loss_key}")
            if weight != 0:
                reduced_loss += losses[loss_key] * weight

        return reduced_loss
