import torch
import torchmetrics
from torchmetrics import RetrievalMRR, RetrievalHitRate, RetrievalMAP

def MRR(similarity):
    """Return the mean reciprocal rank of a query-by-candidate similarity matrix.

    Row ``i`` of ``similarity`` is treated as one query. Column ``i`` is its
    only relevant candidate, i.e. the diagonal is the ground truth.
    Rectangular matrices are supported. Rows without a diagonal entry
    (``i >= similarity.shape[1]``) have no relevant candidate; they are scored
    according to ``RetrievalMRR``'s default ``empty_target_action``.

    Args:
        similarity: 2-D tensor of scores with shape
            ``(num_queries, num_candidates)``. Within each row, torchmetrics
            ranks higher scores first.

    Returns:
        0-dim tensor holding the MRR.

    Raises:
        ValueError: if ``similarity`` is not 2-D (raised by the shape unpack).
    """
    device = similarity.device
    num_queries, num_candidates = similarity.shape

    preds = similarity.reshape(-1)
    target = torch.eye(
        num_queries, num_candidates, dtype=torch.bool, device=device
    ).reshape(-1)
    # Each element's query id is its row index. The id is repeated across the
    # row's columns so it lines up with the row-major flattening of `preds`.
    indexes = torch.arange(num_queries, device=device).repeat_interleave(
        num_candidates
    )

    mrr = RetrievalMRR()
    return mrr(preds, target, indexes=indexes)

def Retrieval_metrics(similarity, ks=(1, 5, 10)):
    """Compute MRR and top-k hit rates for a query-by-candidate similarity matrix.

    Row ``i`` of ``similarity`` is treated as one query. Column ``i`` is its
    only relevant candidate, i.e. the diagonal is the ground truth.
    Rectangular matrices are supported. Rows without a diagonal entry have no
    relevant candidate and are scored according to the torchmetrics default
    ``empty_target_action``.

    Args:
        similarity: 2-D tensor of scores with shape
            ``(num_queries, num_candidates)``. Within each row, torchmetrics
            ranks higher scores first.
        ks: Hit-rate cutoffs to report. The default ``(1, 5, 10)`` reproduces
            the original ``top_1``/``top_5``/``top_10`` keys.

    Returns:
        Dict of Python floats: ``'mrr'`` followed by ``'top_{k}'`` for each
        ``k`` in ``ks``, in that order.

    Raises:
        ValueError: if ``similarity`` is not 2-D (raised by the shape unpack).
    """
    device = similarity.device
    num_queries, num_candidates = similarity.shape

    # Flatten once and reuse for every metric.
    preds = similarity.reshape(-1)
    target = torch.eye(
        num_queries, num_candidates, dtype=torch.bool, device=device
    ).reshape(-1)
    # Each element's query id is its row index. The id is repeated across the
    # row's columns so it lines up with the row-major flattening of `preds`.
    indexes = torch.arange(num_queries, device=device).repeat_interleave(
        num_candidates
    )

    metrics = {}
    mrr = RetrievalMRR()
    metrics['mrr'] = mrr(preds, target, indexes=indexes).item()
    for k in ks:
        # NOTE(review): `k=` matches the torchmetrics version this file was
        # written against. Newer releases name this argument `top_k`.
        # Confirm against the pinned version before upgrading.
        hit_rate = RetrievalHitRate(k=k)
        metrics[f'top_{k}'] = hit_rate(preds, target, indexes=indexes).item()

    return metrics

def MiniRetrieval_metrics(similarity):
    """Return ``(mrr, top_1)`` for a query-by-candidate similarity matrix.

    Row ``i`` of ``similarity`` is treated as one query. Column ``i`` is its
    only relevant candidate, i.e. the diagonal is the ground truth.
    Rectangular matrices are supported. Rows without a diagonal entry have no
    relevant candidate and are scored according to the torchmetrics default
    ``empty_target_action``.

    Args:
        similarity: 2-D tensor of scores with shape
            ``(num_queries, num_candidates)``. Within each row, torchmetrics
            ranks higher scores first.

    Returns:
        Tuple of two Python floats: mean reciprocal rank and top-1 hit rate.

    Raises:
        ValueError: if ``similarity`` is not 2-D (raised by the shape unpack).
    """
    device = similarity.device
    num_queries, num_candidates = similarity.shape

    # Flatten once and reuse for both metrics.
    preds = similarity.reshape(-1)
    target = torch.eye(
        num_queries, num_candidates, dtype=torch.bool, device=device
    ).reshape(-1)
    # Each element's query id is its row index. The id is repeated across the
    # row's columns so it lines up with the row-major flattening of `preds`.
    indexes = torch.arange(num_queries, device=device).repeat_interleave(
        num_candidates
    )

    mrr = RetrievalMRR()
    mrr_score = mrr(preds, target, indexes=indexes)
    # NOTE(review): `k=` matches the torchmetrics version this file targets.
    # Newer releases name it `top_k`.
    top_1 = RetrievalHitRate(k=1)
    top_1_score = top_1(preds, target, indexes=indexes)

    return mrr_score.item(), top_1_score.item()