import os
import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

def _to_displayable(img):
    """Convert one stored image array into something ``imshow`` can render.

    A 3-D array is treated as C,H,W and moved to H,W,C; a single channel is
    squeezed to H,W because ``imshow`` rejects H,W,1 arrays. Values are then
    min-max rescaled to [0, 1] per image (the epsilon avoids division by zero
    on constant images).
    """
    img = np.asarray(img, dtype=np.float32)
    if img.ndim == 3:
        img = img.transpose(1, 2, 0)  # C,H,W -> H,W,C
        if img.shape[-1] == 1:
            img = img[..., 0]
    return (img - img.min()) / (img.max() - img.min() + 1e-8)


def _show_panel(ax, img, title, caption):
    """Draw one image panel: title above, report caption (first 60 chars) below."""
    # cmap only affects 2-D (grayscale) data; RGB images ignore it.
    ax.imshow(_to_displayable(img), cmap="gray")
    ax.axis('off')
    ax.set_title(title)
    # Axes-relative coordinates put the caption just under the image instead
    # of in the title area, independent of the image's pixel size.
    ax.text(0, -0.05, caption[:60], transform=ax.transAxes, va="top",
            fontsize=8, wrap=True)


def show_topk_retrieval(
    model,
    dataloader,
    tokenizer,
    device,
    save_dir,
    topk=5,
    max_vis=20
):
    """
    Visualize medical image-report retrieval results, generating images for Top-k retrieval pairs of the first max_vis samples.

    Features are extracted batch by batch until at least ``max_vis`` samples
    have been collected; that (batch-aligned) set is the retrieval gallery.
    For each of the first ``max_vis`` query images, the reports are ranked by
    cosine similarity between L2-normalized image and text embeddings. One
    PNG is saved per query (``retrieval_{idx}.png`` in ``save_dir``), showing
    the query image with its own report, followed by the images and reports
    of the top-ranked candidates and their similarity scores.

    Args:
        model: Object exposing ``eval()``, ``encode_image(images)`` and
            ``encode_text(input_ids, attention_mask)``. It is left in eval mode.
        dataloader: Iterable of dicts with ``'image'``, ``'input_ids'`` and
            ``'attention_mask'`` tensors.
        tokenizer: Object with a ``decode`` method used to recover report text.
        device: Device the batches are moved to before encoding.
        save_dir: Output directory; created if missing.
        topk: Number of retrieved candidates per query. It is clamped to the
            gallery size.
        max_vis: Number of queries to visualize. It also bounds how much
            of the dataloader is consumed.

    Returns:
        None. An empty dataloader produces no figures.
    """
    os.makedirs(save_dir, exist_ok=True)
    model.eval()
    image_chunks, text_chunks, images_raw, texts_raw = [], [], [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting for visualization"):
            images = batch['image'].to(device)
            input_ids = batch['input_ids'].to(device)
            attn_mask = batch['attention_mask'].to(device)
            image_chunks.append(model.encode_image(images).cpu())
            text_chunks.append(model.encode_text(input_ids, attn_mask).cpu())
            # .float() so half/bfloat16 inputs survive the numpy conversion.
            images_raw.extend(images.float().cpu().numpy())
            texts_raw.extend(
                tokenizer.decode(
                    ids.tolist(),
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True
                )
                for ids in batch['input_ids']
            )
            # Only the first max_vis samples are visualized, so stop early.
            if len(texts_raw) >= max_vis:
                break

    if not image_chunks:
        return  # empty dataloader: nothing to retrieve or draw

    # Cast to float32 so the CPU matmul also works for reduced-precision embeddings.
    image_embeds = torch.nn.functional.normalize(torch.cat(image_chunks, dim=0).float(), dim=-1)
    text_embeds = torch.nn.functional.normalize(torch.cat(text_chunks, dim=0).float(), dim=-1)
    sims = (image_embeds @ text_embeds.T).numpy()  # [N, N] cosine similarities

    n_gallery = sims.shape[0]
    k = max(0, min(topk, n_gallery))  # cannot show more candidates than exist

    for idx in range(min(max_vis, n_gallery)):
        sim_row = sims[idx]
        ranked = np.argsort(-sim_row)[:k]
        # squeeze=False keeps a 2-D axes array even when there is one panel.
        fig, axs = plt.subplots(1, k + 1, figsize=(3 * (k + 1), 3), squeeze=False)
        axs = axs[0]
        _show_panel(axs[0], images_raw[idx], "Query", texts_raw[idx])
        for rank, cand in enumerate(ranked, start=1):
            # Label by retrieval rank (not dataset index) plus the similarity score.
            _show_panel(
                axs[rank],
                images_raw[cand],
                f"Top{rank} ({sim_row[cand]:.2f})",
                texts_raw[cand],
            )
        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f"retrieval_{idx}.png"))
        plt.close(fig)