import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import logging

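# NOTE: The commented-out block below is an earlier, cross-entropy (classification)
# version of RelativePositionLoss. It is disabled; the active class of the same
# name is defined after it.
# class RelativePositionLoss(nn.Module):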
#     """
#     Calculates the loss for the relative position prediction task.

#     This module wraps the CrossEntropyLoss and handles the conversion of
#     real relative positions (e.g., -k, ..., +k) to class indices (e.g., 0, ..., 2k).
#     """

#     def __init__(self, max_relative_position: int, ignore_index: int = -100):
#         """
#         Args:
#             max_relative_position (int): The maximum absolute relative distance `k`.
#             ignore_index (int): Specifies a target value that is ignored and does not
#                                 contribute to the input gradient.
#         """
#         super().__init__()
#         self.max_rel_dist = max_relative_position
#         self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_index)

#     def forward(self, predicted_logits: torch.Tensor, true_relative_positions: torch.Tensor) -> torch.Tensor:
#         """
#         Args:
#             predicted_logits (torch.Tensor): The output from the PositionPredictionHead.
#                                              Shape: (B*N, num_classes) or (B, N, num_classes)
#             true_relative_positions (torch.Tensor): The ground truth relative positions.
#                                                     Shape: (B*N) or (B, N)

#         Returns:
#             torch.Tensor: The calculated cross-entropy loss (a scalar).
#         """
#         # Reshape inputs to (BatchSize, NumClasses) and (BatchSize)
#         num_classes = predicted_logits.shape[-1]
#         logits_flat = predicted_logits.reshape(-1, num_classes)

#         true_pos_flat = true_relative_positions.reshape(-1)

#         # --- Core Logic ---
#         # Convert true relative positions from [-k, k] to target indices [0, 2k]
#         # Any position outside this range will be mapped to an ignored index if we want,
#         # but the `process_spatial_feature_all_perspectives` ensures they are within the window.
#         target_indices = true_pos_flat + self.max_rel_dist

#         # Ensure target indices are of type long for CrossEntropyLoss
#         target_indices = target_indices.long()

#         # Calculate loss
#         loss = self.criterion(logits_flat, target_indices)

#         return loss


class RelativePositionLoss(nn.Module):
    """Mean-squared-error loss between predicted and true relative positions.

    The prediction is converted to a numeric relative position in a way that
    stays differentiable with respect to ``predicted_values`` (a hard
    ``argmax`` would detach the loss from the graph and make ``backward()``
    fail), and is then compared with the ground truth using ``nn.MSELoss``.
    """

    def __init__(self, max_relative_position: int):
        """
        Args:
            max_relative_position: Offset ``k`` subtracted from a class index
                to obtain a signed relative position (index ``i`` maps to
                position ``i - k``).
        """
        super().__init__()
        # Either L1 loss (mean absolute error) or MSE loss (mean squared error)
        # could be used here. L1 is more robust to outliers; MSE penalizes
        # large errors more heavily.
        self.criterion = nn.MSELoss()
        self.max_rel_dist = max_relative_position

    def forward(self, predicted_values: torch.Tensor, true_relative_positions: torch.Tensor) -> torch.Tensor:
        """Compute the MSE between predicted and true relative positions.

        Args:
            predicted_values: Model output whose last dimension is either
                per-class scores over position bins (size > 1) or a single
                value per element (size 1).
            true_relative_positions: Ground-truth relative positions. After
                flattening it must have as many elements as the prediction
                has once its last dimension is reduced.

        Returns:
            torch.Tensor: Scalar MSE loss. It carries gradients back to
            ``predicted_values`` when that tensor requires grad.
        """
        num_outputs = predicted_values.shape[-1]

        if num_outputs == 1:
            # NOTE(review): a trailing dimension of 1 is treated as a directly
            # regressed position (argmax over a single entry is always 0, which
            # would make the prediction a constant). Confirm this matches the
            # prediction head's output convention.
            predicted_positions = predicted_values.squeeze(-1)
        else:
            # Soft-argmax: the softmax-weighted expected class index. Unlike
            # torch.argmax, this keeps the loss connected to the autograd
            # graph, and it approaches the hard argmax as the scores sharpen.
            probabilities = F.softmax(predicted_values.float(), dim=-1)
            class_indices = torch.arange(
                num_outputs,
                device=probabilities.device,
                dtype=probabilities.dtype,
            )
            expected_index = (probabilities * class_indices).sum(dim=-1)
            # Shift the index range [0, num_outputs - 1] by k to get signed positions.
            predicted_positions = expected_index - self.max_rel_dist

        # Flatten both sides and compare them as plain float values.
        predicted_flat = predicted_positions.reshape(-1).float()
        true_pos_flat = true_relative_positions.reshape(-1).float()

        return self.criterion(predicted_flat, true_pos_flat)
