import torch
import torch.nn.functional as F

from .utils import qr_based_pseudoinverse, conjugate_transpose_inverse


def compute_attn_wts(
    x: torch.Tensor, dst: torch.Tensor
) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Hard-assign every token of ``x`` to its best-matching token of ``dst``.

    Similarity is the raw dot product ``dst @ x^T``; each source token is
    assigned to the destination with the highest score (one-hot along the
    destination axis).

    Args:
        x: Source tokens, expected as ``[B, N, C]``.
        dst: Destination tokens, expected as ``[B, num_dst, C]``.

    Returns:
        A tuple ``(avg_A, A_t)``:

        * ``avg_A`` -- ``[B, num_dst, N]``; the one-hot assignment matrix with
          each destination row divided by the number of tokens assigned to
          it, so a non-empty row sums to 1. A destination that received no
          tokens gets an all-zero row.
        * ``A_t`` -- ``[B, N, num_dst]``; the un-normalised one-hot assignment
          matrix, transposed (a view, not a copy).
    """
    # NOTE: inputs are used as-is (no L2 normalisation is applied here), so
    # the scores are plain dot products rather than cosine similarities.
    scores = dst @ x.transpose(-1, -2)  # [B, num_dst, N]

    # Index of the best destination for every source token: [B, N]
    _, best_dst = torch.max(scores, dim=1)

    # Reuse the score buffer for the one-hot assignment matrix.
    assignment = scores
    assignment.zero_()
    assignment.scatter_(dim=1, index=best_dst.unsqueeze(1), value=1)

    # Number of source tokens assigned to each destination: [B, num_dst, 1]
    count_per_dst = assignment.sum(dim=-1, keepdim=True)

    # A destination with no assigned tokens has count 0; dividing by it would
    # turn its (all-zero) row into NaN. Clamping to 1 keeps that row at zero
    # and leaves every non-empty row unchanged.
    avg_assignment = assignment / count_per_dst.clamp(min=1)

    return avg_assignment, assignment.transpose(-1, -2)
