from tqdm import tqdm
import torch
import torch.nn.functional as F

def _scale_by_feature_std(db, eps=1e-6):
    """Divide every feature column of ``db`` by its sample standard deviation.

    No mean-centering is applied; columns are only rescaled. ``eps`` guards
    against division by zero for constant columns.
    """
    if db.size(0) < 2:
        # The sample std of fewer than two rows is undefined (NaN) and would
        # poison every downstream similarity, so leave the rows unscaled.
        return db
    return db / (db.std(dim=0, keepdim=True) + eps)


def _soft_top_k(query_norm, db_norm, db_values, top_k, temperature):
    """Blend the ``top_k`` database rows most similar to each query.

    Similarity is the dot product between ``query_norm`` rows and ``db_norm``
    rows. The selected rows of ``db_values`` are averaged with softmax weights
    computed from ``similarity * temperature``.

    Returns:
        ``(blended, best_sim)``: one blended row per query, and the highest
        similarity found for each query.
    """
    sim = torch.mm(query_norm, db_norm.t())                    # (B, M)
    k = min(top_k, sim.size(1))                                # never ask for more rows than exist
    sim_topk, idx_topk = torch.topk(sim, k=k, dim=1)           # (B, k), sorted descending
    weights = torch.softmax(sim_topk * temperature, dim=1).unsqueeze(-1)  # (B, k, 1)
    blended = (weights * db_values[idx_topk]).sum(dim=1)       # (B, D)
    return blended, sim_topk[:, 0]


def retrieve_features(
    model: torch.nn.Module,
    dataloader,
    lesion_db: torch.Tensor,
    patient_db: torch.Tensor,
    device,
    *,
    lesion_top_k: int = 1,
    patient_top_k: int = 1,
    lesion_temperature: float = 20.0,
    patient_temperature: float = 20.0,
    show_progress: bool = True,
):
    """Extract image features and retrieve matching lesion/patient features.

    For every batch the function:

    1. computes image features with ``model.vit``;
    2. retrieves a softmax-weighted blend of the most similar rows of the
       std-scaled lesion database (similarity is computed on L2-normalized
       vectors);
    3. fuses image and lesion features through ``model.lesion_fc`` and
       L2-normalizes the result;
    4. uses the fused vector to retrieve from the patient database in the
       same way.

    Summary statistics of the top-1 similarities are printed at the end.

    Args:
        model: Module exposing ``eval()``, ``vit(images)`` and
            ``lesion_fc(features)``.
        dataloader: Iterable of batches; the last two elements of each batch
            are taken as ``images`` and ``targets`` respectively.
        lesion_db: 2-D tensor with one lesion feature vector per row.
        patient_db: 2-D tensor with one patient feature vector per row.
        device: Device on which the model runs; batches and both databases
            are moved there.
        lesion_top_k: Number of lesion neighbours to blend (clamped to the
            database size).
        patient_top_k: Number of patient neighbours to blend (clamped to the
            database size).
        lesion_temperature: Multiplier applied to lesion similarities before
            the softmax; larger values concentrate weight on the best match.
        patient_temperature: Same as ``lesion_temperature`` for the patient
            database.
        show_progress: Wrap the dataloader in a ``tqdm`` progress bar.

    Returns:
        A 4-tuple of CPU tensors concatenated over all batches:
        ``(image_feats, retrieved_lesion_feats, retrieved_patient_feats,
        targets)``.

    Raises:
        ValueError: If ``lesion_top_k`` or ``patient_top_k`` is smaller
            than 1.
        RuntimeError: Propagated from ``torch.cat`` when ``dataloader``
            yields no batches.
    """
    if lesion_top_k < 1 or patient_top_k < 1:
        raise ValueError("lesion_top_k and patient_top_k must be >= 1")

    model.eval()

    # Scale per-dimension by the std, then L2-normalize for cosine similarity.
    # The databases must live on the same device as the query features.
    lesion_db_std = _scale_by_feature_std(lesion_db.to(device))
    patient_db_std = _scale_by_feature_std(patient_db.to(device))
    lesion_db_norm = F.normalize(lesion_db_std, dim=1, p=2)
    patient_db_norm = F.normalize(patient_db_std, dim=1, p=2)

    image_feats, retrieved_lesion_feats, retrieved_patient_feats, all_targets = [], [], [], []
    top1_sim_lesion_list, top1_sim_patient_list = [], []

    batches = tqdm(dataloader, desc="Retrieving") if show_progress else dataloader

    for batch in batches:
        # Only the last two entries are used, so only those are transferred.
        images = batch[-2].to(device, non_blocking=True)
        targets = batch[-1].to(device, non_blocking=True)

        # Autocast is explicitly disabled for the retrieval computations.
        with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
            # 1) Image features
            image_feat = model.vit(images)                    # (B, D)
            image_feat_norm = F.normalize(image_feat, dim=1)  # (B, D)

            # 2) Lesion retrieval
            top1_lesion_feat, top1_sim_lesion = _soft_top_k(
                image_feat_norm, lesion_db_norm, lesion_db_std,
                lesion_top_k, lesion_temperature,
            )

            # 3) Lesion + Image projection
            lesion_fused = model.lesion_fc(torch.cat([image_feat, top1_lesion_feat], dim=1))
            lesion_fused = F.normalize(lesion_fused, dim=1)

            # 4) Patient retrieval
            top1_patient_feat, top1_sim_patient = _soft_top_k(
                lesion_fused, patient_db_norm, patient_db_std,
                patient_top_k, patient_temperature,
            )

        # Store outputs on the CPU so device memory does not grow with the dataset.
        image_feats.append(image_feat.cpu())
        retrieved_lesion_feats.append(top1_lesion_feat.cpu())
        retrieved_patient_feats.append(top1_patient_feat.cpu())
        all_targets.append(targets.cpu())

        top1_sim_lesion_list.append(top1_sim_lesion.cpu())
        top1_sim_patient_list.append(top1_sim_patient.cpu())

    top1_sim_lesion_all = torch.cat(top1_sim_lesion_list)
    top1_sim_patient_all = torch.cat(top1_sim_patient_list)

    # Print stats
    print("\nRetrieval Similarity Statistics:")
    print(f"- Lesion  Top-1 Cosine Similarity: mean = {top1_sim_lesion_all.mean():.4f}, std = {top1_sim_lesion_all.std():.4f}")
    print(f"- Patient Top-1 Cosine Similarity: mean = {top1_sim_patient_all.mean():.4f}, std = {top1_sim_patient_all.std():.4f}")

    return (
        torch.cat(image_feats),
        torch.cat(retrieved_lesion_feats),
        torch.cat(retrieved_patient_feats),
        torch.cat(all_targets),
    )
