# llava/model/relational_embedding.py

from typing import Tuple
import torch
import torch.nn as nn

class RelationalPositionalEncoding(nn.Module):
    
    def __init__(self, embedding_dim: int, num_bins_per_axis: int = 5, num_dims: int = 3):
        
        super().__init__()
        self.num_bins_per_axis = num_bins_per_axis
        self.num_dims = num_dims
        self.total_bins = num_bins_per_axis ** self.num_dims
        self.embedding_dim = embedding_dim
        self.relational_embedding = nn.Embedding(self.total_bins, self.embedding_dim)
        boundaries = torch.linspace(-1.0, 1.0, num_bins_per_axis - 1)
        self.register_buffer('boundaries', boundaries)

    def _calculate_relational_ids(self, coords: torch.Tensor) -> torch.Tensor:
        assert coords.shape[-1] == self.num_dims, \
            f"输入坐标有 {coords.shape[-1]} 个维度, 但模块被配置为处理 {self.num_dims} 个维度。"

        
        coords = coords.contiguous()
        delta_coords = coords.unsqueeze(2) - coords.unsqueeze(1)
        delta_coords = delta_coords.contiguous()



        all_bins = []
        for i in range(self.num_dims):
            boundaries = self.boundaries.to(delta_coords.device)
            bins = torch.bucketize(delta_coords[..., i].contiguous(), boundaries, right=False)
            bins = torch.clamp(bins, 0, self.num_bins_per_axis - 1)
            all_bins.append(bins)

        relational_ids = torch.zeros_like(all_bins[0], dtype=torch.long)
        for i in range(self.num_dims):
      
            dim_index = self.num_dims - 1 - i
            power = i
            contribution = all_bins[dim_index] * (self.num_bins_per_axis ** power)
            relational_ids += contribution

        max_id = self.num_bins_per_axis ** self.num_dims - 1
        relational_ids = torch.clamp(relational_ids, 0, max_id)

        return relational_ids.long()

    def forward(self, coords: torch.Tensor) -> torch.Tensor:
        relational_ids = self._calculate_relational_ids(coords)
        device = self.relational_embedding.weight.device

    
        relational_ids = relational_ids.to(device)



  
        bias_vectors = self.relational_embedding(relational_ids)

        result = bias_vectors.permute(0, 3, 1, 2)

        return result