import torch
import torch.nn.functional as F


def euclidean_distance_matrix(
    matrix1: torch.Tensor, matrix2: torch.Tensor
) -> torch.Tensor:
    """
    Compute the pairwise Euclidean distance matrix between two sets of vectors.

    Uses the identity:
        ||x - y||^2 = ||x||^2 - 2 x·y + ||y||^2

    The expanded form suffers from floating-point cancellation: for vectors that
    are (nearly) identical the computed squared distance can come out slightly
    negative, which would make the square root NaN. The squared distances are
    therefore clamped at zero before the root is taken.

    Args:
        matrix1 (Tensor): A tensor of shape (N1, D), where each row is a D-dimensional vector.
        matrix2 (Tensor): A tensor of shape (N2, D), where each row is a D-dimensional vector.

    Returns:
        Tensor: A distance matrix of shape (N1, N2), where entry (i, j) is
                ||matrix1[i] - matrix2[j]||.

    Raises:
        AssertionError: If the two inputs do not have the same number of columns.
    """
    # Ensure both inputs have the same feature dimension
    assert (
        matrix1.shape[1] == matrix2.shape[1]
    ), "Input tensors must have the same number of columns"

    # Cross term: entry (i, j) = -2 * x_i · y_j, shape (N1, N2)
    cross_term = -2 * torch.mm(matrix1, matrix2.T)

    # Squared norms of each row in matrix1, kept as a column: shape (N1, 1)
    sq_norms1 = torch.sum(matrix1.pow(2), dim=1, keepdim=True)

    # Squared norms of each row in matrix2: shape (N2,), broadcast along rows
    sq_norms2 = torch.sum(matrix2.pow(2), dim=1)

    # Broadcasting combines the three terms into the (N1, N2) squared distances
    sq_dists = cross_term + sq_norms1 + sq_norms2

    # Round-off can push true zeros slightly below zero; clamp so sqrt never
    # yields NaN (NaNs would corrupt the downstream trace and argsort).
    sq_dists = torch.clamp(sq_dists, min=0.0)

    return torch.sqrt(sq_dists)


# Alternative implementation using PyTorch's built-in pairwise distance (kept for reference):
# def euclidean_distance_matrix(a, b):
#     return torch.cdist(a, b)


def cal_mmdist_rprecision(
    A_embeddings: torch.Tensor,
    B_embeddings: torch.Tensor,
    r_precision_batch: int = 256,
    top_k: int = 3,
    reduction: bool = False,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Calculate the sum of matching distances and R-precision statistics over embedding batches.

    The inputs are processed in disjoint, full batches of `r_precision_batch` rows;
    trailing rows that do not fill a complete batch are ignored. For each batch,
    embeddings are L2-normalized, pairwise distances are computed, and:
      1. The trace of the distance matrix (sum of the distances between each
         A row and the B row at the same index) is accumulated.
      2. R-precision is calculated up to `top_k` for each query.

    Args:
        A_embeddings (Tensor): Tensor of shape (N, D), e.g., text embeddings.
        B_embeddings (Tensor): Tensor of shape (N, D), e.g., motion embeddings.
        r_precision_batch (int): Batch size for R-precision evaluation.
        top_k (int): Maximum rank cutoff for R-precision (computes top-1, top-2, …, top-k).
        reduction (bool): If True, average the metrics over all valid samples
            (the samples that fell into a complete batch).

    Returns:
        mm_dist (Tensor): 0-dim tensor holding the sum (or, with `reduction=True`,
                          the average) of diagonal distances across batches.
        top_k_mat (Tensor): A tensor of shape (top_k,) containing counts (or averages if
                            `reduction=True`) of queries whose ground-truth match appeared
                            within the top-1, top-2, …, top-k ranks.

        When there is no complete batch (N < r_precision_batch), both results are
        zeros and no averaging is performed.
    """
    num_samples = A_embeddings.shape[0]
    num_batches = num_samples // r_precision_batch
    device = A_embeddings.device

    # Start from a tensor (not a Python float) so the return type is a Tensor
    # even when the loop below never runs.
    mm_dist = torch.zeros((), dtype=A_embeddings.dtype, device=device)
    top_k_mat = torch.zeros((top_k,), device=device)

    # Process in disjoint batches of size r_precision_batch
    for i in range(num_batches):
        # Slice the next batch
        start = i * r_precision_batch
        end = start + r_precision_batch
        group_A = F.normalize(A_embeddings[start:end], p=2, dim=1, eps=1e-12)
        group_B = F.normalize(B_embeddings[start:end], p=2, dim=1, eps=1e-12)

        # Compute pairwise Euclidean distances for this batch
        dist_mat = euclidean_distance_matrix(group_A, group_B)

        # Accumulate the trace (sum of distances between each vector and its counterpart).
        # Out-of-place add lets dtype promotion follow the distance matrix.
        mm_dist = mm_dist + torch.trace(dist_mat)

        # Compute ranking matrix (indices of sorted distances)
        argsort_mat = torch.argsort(dist_mat, dim=1)

        # Update R-precision counts
        top_k_mat += calculate_top_k(argsort_mat, top_k=top_k).sum(dim=0)

    valid_samples = num_batches * r_precision_batch
    # Guard against dividing by zero when no complete batch was available.
    if reduction and valid_samples > 0:
        mm_dist = mm_dist / valid_samples
        top_k_mat /= valid_samples

    return mm_dist, top_k_mat


def calculate_top_k(mat, top_k):
    """
    Compute cumulative top-k hit indicators from a ranking matrix.

    Row ``i`` of ``mat`` lists candidate indices in ranked order; the
    ground-truth candidate for row ``i`` is index ``i`` itself.

    Args:
        mat (Tensor): Ranking matrix of shape (N, M) holding candidate indices,
            e.g. the output of ``torch.argsort`` on a distance matrix.
        top_k (int): Number of rank cutoffs to evaluate (top-1 … top-k).

    Returns:
        Tensor: Boolean tensor of shape (N, top_k). Entry (i, k) is True when
        index ``i`` appears among the first ``k + 1`` entries of row ``i``.
        For ``top_k == 0`` the result has shape (N, 0).

    Raises:
        ValueError: If ``top_k`` is negative.
        IndexError: If ``top_k`` exceeds the number of columns of ``mat``.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")
    if top_k > mat.shape[1]:
        raise IndexError(
            f"top_k ({top_k}) exceeds the number of ranked candidates ({mat.shape[1]})"
        )

    # Ground-truth index per row as a column vector, created directly on the
    # target device; broadcasting replaces an explicit (N, N) index matrix.
    row_ids = torch.arange(mat.shape[0], device=mat.device).unsqueeze(1)

    # hits[i, k] is True when the rank-k candidate of row i is the ground truth
    hits = mat[:, :top_k] == row_ids

    # A running count > 0 means the ground truth was seen at or before rank k,
    # which is the cumulative logical OR along the rank dimension.
    return hits.long().cumsum(dim=1) > 0
