import torch

def a2t_metric(audio_embed: torch.Tensor, text_embed: torch.Tensor):
    """
    Evaluate audio-to-text retrieval metrics assuming 1-to-1 matching.
    Embeddings must be L2-normalized already.

    Row ``i`` of ``audio_embed`` is used as a query and row ``i`` of
    ``text_embed`` is treated as its single ground-truth match.

    Args:
        audio_embed (torch.Tensor): shape (N, D), normalized
        text_embed (torch.Tensor): shape (N, D), normalized

    Returns:
        r1, r5, r10, medr, meanr, mAP10

        r1 / r5 / r10 are the percentage of queries whose match is ranked in
        the top 1 / 5 / 10. medr and meanr are the median and mean 1-based
        rank of the match. mAP10 is the mean over ALL queries of the
        reciprocal 1-based rank, where a query whose match falls outside the
        top 10 contributes 0, expressed as a percentage.
    """
    sim = audio_embed @ text_embed.T  # (N, N), cosine sim == dot product since normalized

    N = sim.size(0)
    sorted_indices = sim.argsort(dim=1, descending=True)  # candidate indices per row, best first
    ground_truth = torch.arange(N, device=sim.device).unsqueeze(1)  # (N, 1)
    # The column where a row's own index shows up is that query's 0-based rank.
    ranks = (sorted_indices == ground_truth).nonzero(as_tuple=False)[:, 1]  # (N,)

    in_top1 = ranks < 1
    in_top5 = ranks < 5
    in_top10 = ranks < 10

    r1 = in_top1.float().mean().item() * 100
    r5 = in_top5.float().mean().item() * 100
    r10 = in_top10.float().mean().item() * 100

    # Average over every query: misses outside the top 10 must count as 0
    # rather than being dropped from the denominator, otherwise the score is
    # inflated whenever some queries miss.
    if ranks.numel() > 0:
        reciprocal_rank = 1.0 / (ranks.float() + 1)
        mAP10 = (reciprocal_rank * in_top10).mean().item() * 100
    else:
        mAP10 = 0.0

    medr = ranks.median().item() + 1
    meanr = ranks.float().mean().item() + 1

    return r1, r5, r10, medr, meanr, mAP10


def t2a_metric(text_embed: torch.Tensor, audio_embed: torch.Tensor):
    """
    Evaluate text-to-audio retrieval metrics assuming 1-to-1 matching.
    Embeddings must be L2-normalized already.

    Row ``i`` of ``text_embed`` is used as a query and row ``i`` of
    ``audio_embed`` is treated as its single ground-truth match.

    Args:
        text_embed (torch.Tensor): shape (N, D), normalized
        audio_embed (torch.Tensor): shape (N, D), normalized

    Returns:
        r1, r5, r10, medr, meanr, mAP10

        r1 / r5 / r10 are the percentage of queries whose match is ranked in
        the top 1 / 5 / 10. medr and meanr are the median and mean 1-based
        rank of the match. mAP10 is the mean over ALL queries of the
        reciprocal 1-based rank, where a query whose match falls outside the
        top 10 contributes 0, expressed as a percentage.
    """
    sim = text_embed @ audio_embed.T  # (N, N), cosine similarity

    N = sim.size(0)
    sorted_indices = sim.argsort(dim=1, descending=True)  # candidate indices per row, best first
    ground_truth = torch.arange(N, device=sim.device).unsqueeze(1)  # (N, 1)
    # The column where a row's own index shows up is that query's 0-based rank.
    ranks = (sorted_indices == ground_truth).nonzero(as_tuple=False)[:, 1]  # (N,)

    in_top1 = ranks < 1
    in_top5 = ranks < 5
    in_top10 = ranks < 10

    r1 = in_top1.float().mean().item() * 100
    r5 = in_top5.float().mean().item() * 100
    r10 = in_top10.float().mean().item() * 100

    # Average over every query: misses outside the top 10 must count as 0
    # rather than being dropped from the denominator, otherwise the score is
    # inflated whenever some queries miss.
    if ranks.numel() > 0:
        reciprocal_rank = 1.0 / (ranks.float() + 1)
        mAP10 = (reciprocal_rank * in_top10).mean().item() * 100
    else:
        mAP10 = 0.0

    medr = ranks.median().item() + 1
    meanr = ranks.float().mean().item() + 1

    return r1, r5, r10, medr, meanr, mAP10


def multi_a2t(audio_embed: torch.Tensor, text_embed: torch.Tensor, num_repeats: int):
    """
    Evaluate audio-to-text retrieval metrics when each audio has several
    ground-truth texts.

    The inputs are read in groups of ``num_repeats`` rows: for audio ``i`` the
    query is ``audio_embed[num_repeats * i]`` (only the first row of each
    group is used) and its ground-truth texts are rows
    ``num_repeats * i`` through ``num_repeats * i + num_repeats - 1`` of
    ``text_embed``. Similarity is the dot product of the embeddings.

    Args:
        audio_embed (torch.Tensor): shape (num_audios * num_repeats, D)
        text_embed (torch.Tensor): shape (M, D) with
            M >= num_audios * num_repeats
        num_repeats (int): number of ground-truth texts per audio

    Returns:
        r1, r5, r10, medr, meanr, mAP10

        Recall and rank statistics use the best-ranked ground-truth text of
        each audio. mAP10 is the mean over audios of the average precision
        computed from the ground-truth texts ranked in the top 10, normalized
        by ``num_repeats``, expressed as a percentage.

    Raises:
        ValueError: if ``text_embed`` has fewer rows than the ground-truth
            indices implied by ``audio_embed`` and ``num_repeats``.
    """
    num_audios = audio_embed.size(0) // num_repeats
    num_texts = text_embed.size(0)
    if num_texts < num_audios * num_repeats:
        raise ValueError(
            f"text_embed has {num_texts} rows but {num_audios * num_repeats} "
            "ground-truth texts are required"
        )

    device = audio_embed.device
    ranks = torch.zeros(num_audios, device=device)
    AP10 = torch.zeros(num_audios, device=device)  # stays 0 for audios with no top-10 hit

    # Loop-invariant helpers.
    positions = torch.arange(num_texts, device=device)
    hit_counts = torch.arange(1, num_repeats + 1, device=device)

    for index in range(num_audios):
        start = num_repeats * index
        audio = audio_embed[start]

        sim = torch.matmul(audio, text_embed.T)  # (M,)
        sorted_indices = sim.argsort(descending=True)

        # Invert the permutation so text_rank[j] is the 0-based rank of text j.
        text_rank = torch.empty_like(sorted_indices)
        text_rank[sorted_indices] = positions

        gt_ranks = text_rank[start:start + num_repeats]
        ranks[index] = gt_ranks.min()  # best-ranked ground-truth text

        # 1-based positions of the top-10 hits, in ascending rank order.
        # Sorting is essential: precision at the k-th hit is k / position,
        # which is only valid when hits are visited from best to worst.
        hits = gt_ranks[gt_ranks < 10].sort().values + 1
        if hits.numel() > 0:
            AP10[index] = torch.sum(hit_counts[:hits.numel()] / hits) / num_repeats

    r1 = (ranks < 1).float().mean().item() * 100
    r5 = (ranks < 5).float().mean().item() * 100
    r10 = (ranks < 10).float().mean().item() * 100
    mAP10 = AP10.mean().item() * 100
    medr = ranks.median().item() + 1
    meanr = ranks.mean().item() + 1

    return r1, r5, r10, medr, meanr, mAP10
