import torch
import torch.nn.functional as F

import wandb

# Choose the compute backend once at import time. Preference order is
# Apple MPS, then CUDA, then plain CPU; the CUDA probe only runs when MPS
# is unavailable.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

def evaluate_retrieval(
    img_autoencoder,
    text_autoencoder,
    test_loader,
):
    """Compute recall@1 retrieval metrics on the autoencoders' latents.

    Every batch from ``test_loader`` is encoded, the latents are
    L2-normalised, and four accuracies are derived from dot-product
    similarities:

    - Full image→text: each image vs all 2N captions (N positives followed
      by N hard negatives).
    - Full text→image: each positive caption vs all N images.
    - Easy: per image, its positive caption vs its own hard negative.
    - Medium-hard: per image, its positive, its hard negative, and up to
      3 random hard negatives that belong to *other* images (fewer when
      the set holds fewer than 4 samples).

    Only the easy and medium-hard figures are printed and passed to
    ``wandb.log``; all four are returned.

    Args:
        img_autoencoder: Module applied to image embeddings. Its call must
            return a 2-tuple; the first item is used as the latent and the
            second is ignored.
        text_autoencoder: Module applied to caption embeddings, same
            return convention as ``img_autoencoder``.
        test_loader: Iterable of ``(img_embs, pos_txt_embs, neg_txt_embs)``
            batches. Batches are moved to the device of
            ``img_autoencoder``'s first parameter before encoding.

    Returns:
        Tuple ``(img2txt_acc, txt2img_acc, easy_acc, medium_acc)`` of
        Python floats in ``[0, 1]``.

    Raises:
        ValueError: If ``test_loader`` provides no samples.

    Note:
        Both models are switched to eval mode and are left in that mode
        when the function returns.
    """
    model_device = next(img_autoencoder.parameters()).device
    img_autoencoder.eval()
    text_autoencoder.eval()

    img_Z, pos_Z, neg_Z = [], [], []

    with torch.no_grad():
        for img_embs, pos_txt_embs, neg_txt_embs in test_loader:
            z_img, _ = img_autoencoder(img_embs.to(model_device))
            z_pos, _ = text_autoencoder(pos_txt_embs.to(model_device))
            z_neg, _ = text_autoencoder(neg_txt_embs.to(model_device))

            img_Z.append(z_img)
            pos_Z.append(z_pos)
            neg_Z.append(z_neg)

    # torch.cat rejects an empty list with an opaque message; fail clearly.
    if not img_Z:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # concatenate & normalize
    img_Z = F.normalize(torch.cat(img_Z, dim=0), dim=1)  # [N, d]
    pos_Z = F.normalize(torch.cat(pos_Z, dim=0), dim=1)  # [N, d]
    neg_Z = F.normalize(torch.cat(neg_Z, dim=0), dim=1)  # [N, d]

    N = img_Z.size(0)
    if N == 0:
        # Every accuracy below divides by N.
        raise ValueError("test_loader yielded no samples; nothing to evaluate")

    device = img_Z.device
    row_ids = torch.arange(N, device=device)

    # --- Full image-to-text recall@1 over 2N captions ---
    # Positives occupy the first N rows of the bank, so image i is correct
    # exactly when the best-scoring caption has index i.
    txt_bank = torch.cat([pos_Z, neg_Z], dim=0)          # [2N, d]
    pred_it = (img_Z @ txt_bank.T).argmax(dim=1)         # [N]
    img2txt_acc = (pred_it == row_ids).sum().item() / N

    # --- Full text-to-image recall@1 over N images ---
    pred_ti = (pos_Z @ img_Z.T).argmax(dim=1)            # [N]
    txt2img_acc = (pred_ti == row_ids).sum().item() / N

    # --- Easy recall: pos vs hard neg ---
    sim_pos = (img_Z * pos_Z).sum(dim=1)  # [N]
    sim_neg = (img_Z * neg_Z).sum(dim=1)  # [N]
    easy_acc = (sim_pos > sim_neg).float().mean().item()

    # --- Medium-hard recall: pos, hard neg, + up to 3 random other negs ---
    # With fewer than 4 samples there are not 3 distinct "other" rows, so
    # take as many as exist instead of indexing past the end.
    num_random = min(3, N - 1)

    # One permutation of range(N - 1) per image; keep the first few entries.
    picks = torch.stack([
        torch.randperm(N - 1, device=device)[:num_random]
        for _ in range(N)
    ])  # [N, num_random]

    # Map a pick p in [0, N-2] to a row index that skips i: values below i
    # stay, values at or above i shift up by one. This is the same result as
    # indexing into the list of all rows with row i removed.
    other_idx = picks + (picks >= row_ids.unsqueeze(1)).to(picks.dtype)

    sim_other = (img_Z.unsqueeze(1) * neg_Z[other_idx]).sum(dim=2)  # [N, num_random]

    # Candidate order per row: [pos(i), neg(i), neg(j1), neg(j2), neg(j3)].
    cand_sims = torch.cat(
        [sim_pos.unsqueeze(1), sim_neg.unsqueeze(1), sim_other],
        dim=1,
    )  # [N, 2 + num_random]

    # index 0 == the positive
    medium_correct = (cand_sims.argmax(dim=1) == 0).sum().item()
    medium_acc = medium_correct / N

    # print & log
    print(f"Easy (pos vs neg):      {easy_acc*100:.2f}%")
    print(f"Medium-hard (1+1+3 neg):{medium_acc*100:.2f}%")

    wandb.log({
        "Easy_R@1_pos_vs_neg":   easy_acc,
        "Medium_R@1_5_options":  medium_acc,
    })

    return img2txt_acc, txt2img_acc, easy_acc, medium_acc
