import torch

def compute_chamfer_loss_torch(
    pred: torch.Tensor,       # (N, 2) tensor
    gt: torch.Tensor,         # (N, 2) tensor
    ee_mask: torch.Tensor,    # (N,) tensor (0/1 mask)
    neigh_indices: torch.Tensor,  # (N,) tensor (indices of nearest neighbors)
    loss_weight: float,      # scalar weight
):
    """Return the weighted mean squared gt -> pred nearest-neighbour distance.

    For every index ``i`` with ``ee_mask[i] == 0``, the ground-truth point
    ``gt[i]`` is compared with ``pred[neigh_indices[i]]``, its precomputed
    nearest prediction. The squared Euclidean distances are summed,
    multiplied by ``loss_weight`` and divided by the number of valid points.

    Args:
        pred: predicted points, annotated as shape (N, 2).
        gt: ground-truth points, annotated as shape (N, 2). Indexed with the
            same mask as ``pred``.
        ee_mask: per-point mask. Entries equal to 0 are kept; all others are
            ignored.
        neigh_indices: for each gt point, an index into ``pred``. Entries at
            masked positions are never read.
        loss_weight: scalar multiplier applied to the loss.

    Returns:
        A 0-dim tensor. If no point is valid, a zero tensor with ``pred``'s
        dtype and device is returned instead of NaN (0 / 0).
    """
    # Keep only points whose mask entry is 0.
    valid_mask = ee_mask == 0
    # With no valid points the normalisation would be 0 / 0 -> NaN.
    # Return 0, consistent with the other losses in this module.
    if not valid_mask.any():
        return pred.new_zeros(())

    gt_valid = gt[valid_mask]  # (M, 2)
    # Nearest predicted point for each valid gt point (gt -> pred direction).
    nn_pred = pred[neigh_indices[valid_mask]]  # (M, 2)
    # Squared distance computed directly; avoids the sqrt/square round-trip
    # of ``torch.norm(...) ** 2``.
    squared_dists = ((nn_pred - gt_valid) ** 2).sum(dim=1)  # (M,)

    num_valid = valid_mask.sum().float()
    return loss_weight * squared_dists.sum() / num_valid


def compute_chamfer_loss_torch_bidirectional(
        pred: torch.Tensor,  # (N, 2)
        gt: torch.Tensor,  # (N, 2)
        ee_mask: torch.Tensor,  # (N,) binary mask (0=valid)
        neigh_indices: torch.Tensor,  # (N,) gt -> pred nearest-neighbour indices
        neigh_indices_pred_to_gt: torch.Tensor,  # (N,) pred -> gt nearest-neighbour indices
        loss_weight: float,
):
    """Return the weighted symmetric (two-way) squared nearest-neighbour loss.

    Only indices with ``ee_mask == 0`` contribute. Two sums are formed:

    * pred -> gt: ``||pred[i] - gt[neigh_indices_pred_to_gt[i]]||^2``
    * gt -> pred: ``||gt[i] - pred[neigh_indices[i]]||^2``

    Their total is multiplied by ``loss_weight`` and divided by
    ``2 * num_valid``, i.e. it is the average of the two per-direction means.

    Args:
        pred: predicted points, annotated as shape (N, 2).
        gt: ground-truth points, annotated as shape (N, 2).
        ee_mask: per-point mask; 0 marks a valid point.
        neigh_indices: for each gt point, an index into ``pred``.
        neigh_indices_pred_to_gt: for each pred point, an index into ``gt``.
        loss_weight: scalar multiplier applied to the loss.

    Returns:
        A 0-dim tensor. If no point is valid, a zero tensor with ``pred``'s
        dtype and device is returned.
    """
    # 1. Drop masked points (keep those with ee_mask == 0).
    valid_mask = ee_mask == 0
    if not valid_mask.any():
        # Match pred's dtype/device so the zero combines cleanly with other losses.
        return pred.new_zeros(())

    pred_valid = pred[valid_mask]  # (M, 2)
    gt_valid = gt[valid_mask]  # (M, 2)

    # 2. pred -> gt term, using the precomputed pred -> gt mapping
    #    restricted to valid points.
    pred_nn_in_gt = gt[neigh_indices_pred_to_gt[valid_mask]]  # (M, 2)
    sq_pred2gt = ((pred_valid - pred_nn_in_gt) ** 2).sum(dim=1)  # (M,)

    # 3. gt -> pred term, using the precomputed gt -> pred mapping
    #    restricted to valid points.
    gt_nn_in_pred = pred[neigh_indices[valid_mask]]  # (M, 2)
    sq_gt2pred = ((gt_valid - gt_nn_in_pred) ** 2).sum(dim=1)  # (M,)

    # 4. Normalise by 2 * M (both directions) and apply the weight.
    total = sq_pred2gt.sum() + sq_gt2pred.sum()
    return loss_weight * total / (2 * valid_mask.sum().float())


import torch


def compute_track_loss_torch(
        pred: torch.Tensor,  # (N, 3) tensor [x, y, z]
        gt: torch.Tensor,  # (N, 3) tensor [x, y, z]
        ee_mask: torch.Tensor,  # (N,) tensor (0/1 mask)
        loss_weight: float,  # scalar weight
        beta: float = 1.0,  # smooth-L1 transition point (1.0 = previous behaviour)
):
    """Return the weighted element-wise mean smooth-L1 loss over valid points.

    Per coordinate, with ``d = |pred - gt|``::

        0.5 * d**2 / beta   if d < beta
        d - 0.5 * beta      otherwise

    This is standard smooth-L1; with the default ``beta=1.0`` it is the
    original ``0.5 * d**2`` / ``d - 0.5`` rule. The per-coordinate terms of
    all points with ``ee_mask == 0`` are averaged, i.e. divided by
    ``M * D``. That equals the former ``num_valid * 3`` for D = 3, and the
    result is scaled by ``loss_weight``.

    Args:
        pred: predicted coordinates, annotated as shape (N, 3).
        gt: target coordinates, same shape as ``pred``.
        ee_mask: per-point mask; entries equal to 0 are kept.
        loss_weight: scalar multiplier applied to the loss.
        beta: smooth-L1 threshold, passed to
            ``torch.nn.functional.smooth_l1_loss``.

    Returns:
        A 0-dim tensor. If no point is valid, a zero tensor with ``pred``'s
        dtype and device is returned.
    """
    # Keep only points whose mask entry is 0.
    valid_mask = ee_mask == 0
    # Bail out before any work. A mean over zero elements would be NaN.
    if not valid_mask.any():
        return pred.new_zeros(())

    # "mean" divides the summed per-coordinate terms by M * D.
    mean_loss = torch.nn.functional.smooth_l1_loss(
        pred[valid_mask],
        gt[valid_mask],
        reduction="mean",
        beta=beta,
    )
    return loss_weight * mean_loss

