import torch
import numpy as np

@torch.no_grad()
def compute_retrieval(img_feats, txt_feats, topk=(1, 5, 10)):
    """Compute Recall@K for paired image-text retrieval in both directions.

    Image ``i`` and text ``i`` are treated as the only matching pair, so the
    ground truth is the diagonal of the similarity matrix. Both inputs are
    L2-normalised along their last dimension first, which makes the
    similarities cosine similarities.

    Args:
        img_feats: image embeddings, expected shape ``[N, D]``.
        txt_feats: text embeddings, expected shape ``[N, D]``. Row ``i`` is
            the text paired with image ``i``.
        topk: cut-offs K at which recall is reported. A K larger than N
            simply considers every candidate.

    Returns:
        Dict mapping ``'image2text_R@K'`` and ``'text2image_R@K'`` (image2text
        keys first) to the fraction of queries whose paired item appears
        among the K most similar candidates.

    Raises:
        ValueError: if ``img_feats`` and ``txt_feats`` have a different number
            of rows (the diagonal pairing would be undefined).
    """
    if img_feats.shape[0] != txt_feats.shape[0]:
        raise ValueError(
            "img_feats and txt_feats must have the same number of rows, "
            f"got {img_feats.shape[0]} and {txt_feats.shape[0]}"
        )

    img_feats = torch.nn.functional.normalize(img_feats, dim=-1)
    txt_feats = torch.nn.functional.normalize(txt_feats, dim=-1)
    sim = img_feats @ txt_feats.T  # [N, N]

    # Ground-truth index per query, created on sim's device: a CPU arange
    # compared against CUDA rank indices would raise a device-mismatch error.
    gt = torch.arange(sim.shape[0], device=sim.device).unsqueeze(1)

    res = {}
    # Rows of `sim` are image queries; rows of `sim.T` are text queries.
    for direction, scores in (('image2text', sim), ('text2image', sim.T)):
        ranks = scores.argsort(dim=-1, descending=True)
        for k in topk:
            hits = (ranks[:, :k] == gt).any(dim=1)
            res[f'{direction}_R@{k}'] = hits.float().mean().item()
    return res

@torch.no_grad()
def compute_bilingual_consistency(txt_en_feats, txt_zh_feats):
    """Return the average cosine similarity of paired English/Chinese embeddings.

    The two inputs are combined element-wise, so row ``i`` of ``txt_en_feats``
    is compared only with row ``i`` of ``txt_zh_feats``. This is a paired
    comparison, not an all-pairs one. The inputs are presumably aligned
    row-for-row (same shape); confirm with the caller.

    Returns:
        ``{'cosine_similarity_mean': float}``, the mean of the per-pair cosines.
    """
    to_unit = torch.nn.functional.normalize
    en_unit = to_unit(txt_en_feats, dim=-1)
    zh_unit = to_unit(txt_zh_feats, dim=-1)
    # The dot product of unit vectors over the last axis is each pair's cosine.
    per_pair_cos = torch.sum(en_unit * zh_unit, dim=-1)
    mean_cos = per_pair_cos.mean().item()
    return {'cosine_similarity_mean': mean_cos}