import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, top_k_accuracy_score
import json
import os
from collections import defaultdict


# ------------------------------
# 1. Zero-Shot Classification (§4.2.1)
# ------------------------------
def _resolve_zero_shot_prompt_fn(domain_prompts, domain=None):
    """Return the callable that turns a class name into a text prompt."""
    if domain_prompts is None:
        # Default prompt (§4.2.1: "a photo of a {class}")
        return lambda cls: f"a photo of a {cls}"
    if callable(domain_prompts):
        return domain_prompts
    if not isinstance(domain_prompts, dict):
        raise TypeError(
            "domain_prompts must be None, a callable, or a dict mapping "
            f"domain name -> callable, got {type(domain_prompts).__name__}"
        )

    if domain is not None:
        if domain not in domain_prompts:
            raise KeyError(
                f"domain {domain!r} not found in domain_prompts "
                f"(available: {list(domain_prompts)})"
            )
        prompt_fn = domain_prompts[domain]
    elif len(domain_prompts) == 1:
        # Unambiguous: a single domain was supplied, so use it.
        prompt_fn = next(iter(domain_prompts.values()))
    else:
        raise ValueError(
            "domain_prompts contains several domains; pass `domain` to select one "
            f"(available: {list(domain_prompts)})"
        )

    if not callable(prompt_fn):
        raise TypeError("each value in domain_prompts must be a callable: cls -> prompt")
    return prompt_fn


def zero_shot_classification(model, dataloader, device, class_names, domain_prompts=None, save_path=None,
                             domain=None):
    """Evaluate zero-shot classification accuracy (§4.2.1).

    Every class name is turned into a text prompt and embedded once; each image
    is then assigned the class whose text embedding has the highest dot-product
    with the image embedding.

    Args:
        model: Called as ``model(images=..., texts=...)`` and expected to return
            an ``(image_embedding, text_embedding)`` pair; ``eval()`` is invoked
            on it before evaluation.
        dataloader: Iterable yielding ``(images, labels)`` batches. Labels are
            flattened, so both ``[B]`` and ``[B, 1]`` layouts are accepted.
        device: Device the image batches (and text embeddings) are moved to.
        class_names: Sequence of class names; position ``i`` is assumed to be
            the class with integer label ``i`` (TODO confirm against dataset).
        domain_prompts: Either a callable ``cls -> prompt`` or a dict of
            domain-specific prompt callables
            (e.g. ``{"MedMNIST2D": lambda cls: "..."}``). ``None`` selects the
            default prompt ``"a photo of a {cls}"``.
        save_path: Optional directory; when given, the metrics are written to
            ``zero_shot_metrics.json`` inside it (the directory is created if
            missing).
        domain: Key selecting the entry of ``domain_prompts`` when that argument
            is a dict with more than one domain.

    Returns:
        dict with ``top1_accuracy``, ``top5_accuracy`` and ``num_samples``.

    Raises:
        ValueError: If ``class_names`` is empty, or a multi-domain
            ``domain_prompts`` dict is given without ``domain``.
        KeyError: If ``domain`` is not a key of ``domain_prompts``.
        TypeError: If ``domain_prompts`` (or the selected entry) is not callable.
    """
    if len(class_names) == 0:
        raise ValueError("class_names must contain at least one class")

    prompt_fn = _resolve_zero_shot_prompt_fn(domain_prompts, domain)
    model.eval()

    # Precompute text embeddings (batch to avoid OOM)
    logger.info("Precomputing text embeddings for zero-shot classification...")
    text_embeds = []
    batch_size = 32
    with torch.no_grad():
        for i in range(0, len(class_names), batch_size):
            batch_prompts = [prompt_fn(cls) for cls in class_names[i:i + batch_size]]
            # Extract text embeddings (model.forward with images=None)
            _, txt_emb = model(images=None, texts=batch_prompts)
            text_embeds.append(txt_emb)
    # One row per class; kept on `device` so it can be multiplied with image embeddings.
    text_embeds = torch.cat(text_embeds, dim=0).to(device)

    # Top-5 degenerates to "top-all" when fewer than five classes exist.
    top_k = min(5, text_embeds.shape[0])
    num_samples = 0
    top1_hits = 0
    top5_hits = 0

    # Evaluate images
    logger.info("Evaluating zero-shot classification...")
    with torch.no_grad():
        for img_tensor, labels in tqdm(dataloader, desc="Zero-Shot Eval"):
            img_tensor = img_tensor.to(device)

            # Extract image embeddings (model.forward with texts=None)
            img_emb, _ = model(images=img_tensor, texts=None)

            # Compute similarity (§4.2.1: logits = img_emb @ text_emb.T)
            logits = torch.matmul(img_emb, text_embeds.t())

            # Top-k must be taken from the full score matrix; the argmax alone
            # cannot tell whether the true class is among the five best.
            # Columns are sorted by descending score, so column 0 is the top-1 prediction.
            topk_idx = logits.topk(top_k, dim=1).indices.cpu()
            targets = torch.as_tensor(labels).reshape(-1, 1).cpu()
            hits = topk_idx == targets

            top1_hits += int(hits[:, 0].sum())
            top5_hits += int(hits.any(dim=1).sum())
            num_samples += targets.shape[0]

    # Calculate metrics (§4.2.1); an empty dataloader yields 0.0 instead of a division error.
    top1_acc = top1_hits / num_samples if num_samples else 0.0
    top5_acc = top5_hits / num_samples if num_samples else 0.0
    metrics = {
        "top1_accuracy": top1_acc,
        "top5_accuracy": top5_acc,
        "num_samples": num_samples
    }

    # Log and save
    logger.info("Zero-Shot Classification Results:")
    logger.info(f"  Top-1 Accuracy: {top1_acc:.4f}")
    logger.info(f"  Top-5 Accuracy: {top5_acc:.4f}")

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        with open(os.path.join(save_path, "zero_shot_metrics.json"), "w") as f:
            json.dump(metrics, f, indent=4)
    return metrics


# ------------------------------
# 2. Cross-Modal Retrieval (§4.2.2: R@1/R@5/R@10 + MRR)
# ------------------------------
def _to_hashable_ids(img_ids):
    """Convert one batch of image IDs into plain hashable Python values."""
    # Tensors hash by object identity, so two tensors holding the same ID would
    # be treated as different dict keys; unwrap them into Python scalars.
    if torch.is_tensor(img_ids):
        return img_ids.reshape(-1).tolist()
    return [img_id.item() if torch.is_tensor(img_id) else img_id for img_id in img_ids]


def _group_mean(embeddings, group_idx, num_groups):
    """Average the rows of `embeddings` that share the same group index."""
    sums = torch.zeros(num_groups, embeddings.shape[1], dtype=embeddings.dtype)
    sums.index_add_(0, group_idx, embeddings)
    counts = torch.bincount(group_idx, minlength=num_groups).to(embeddings.dtype).unsqueeze(1)
    return sums / counts


def _diagonal_ranks(sim_matrix):
    """Return the 1-based rank of each row's diagonal (ground-truth) entry."""
    # rank = 1 + number of candidates scoring strictly higher than the target;
    # equivalent to locating the target in a descending sort when there are no ties.
    target_scores = sim_matrix.diagonal().unsqueeze(1)
    return ((sim_matrix > target_scores).sum(dim=1) + 1).numpy()


def _retrieval_metrics(ranks):
    """Compute R@1/R@5/R@10, MRR and median rank from 1-based ranks."""
    ranks_np = np.asarray(ranks, dtype=np.float64)
    return {
        "R@1": float(np.mean(ranks_np <= 1)),
        "R@5": float(np.mean(ranks_np <= 5)),
        "R@10": float(np.mean(ranks_np <= 10)),
        "MRR": float(np.mean(1.0 / ranks_np)),
        "Median_Rank": float(np.median(ranks_np))
    }


def cross_modal_retrieval(model, dataloader, device, save_path=None):
    """Evaluate image<->text retrieval (§4.2.2: R@1/R@5/R@10 + MRR).

    Embeddings are extracted for every (image, caption) pair, then pooled per
    image ID: caption embeddings of the same image are averaged, and repeated
    rows of the same image are collapsed into one embedding. Retrieval is scored
    on the resulting square similarity matrix, where entry ``(i, i)`` is the
    ground-truth match in both directions.

    Args:
        model: Called as ``model(images=..., texts=...)`` and expected to return
            an ``(image_embedding, text_embedding)`` pair with one row per
            sample; ``eval()`` is invoked on it before evaluation.
        dataloader: Iterable yielding ``(images, captions, img_ids)`` batches.
            ``img_ids`` may be a tensor or a sequence of hashable IDs.
        device: Device the image batches are moved to.
        save_path: Optional directory; when given, the metrics are written to
            ``cross_modal_metrics.json`` inside it (the directory is created if
            missing).

    Returns:
        dict with keys ``image_to_text`` and ``text_to_image``, each mapping to
        a dict with ``R@1``, ``R@5``, ``R@10``, ``MRR`` and ``Median_Rank``.

    Raises:
        ValueError: If the dataloader yields no batches, or the number of image
            IDs does not match the number of embedding rows.
    """
    model.eval()
    # Store embeddings and metadata
    img_emb_list = []
    txt_emb_list = []
    img_id_list = []

    # Extract embeddings
    logger.info("Extracting image-text embeddings for cross-modal retrieval...")
    with torch.no_grad():
        for img_tensor, captions, img_ids in tqdm(dataloader, desc="Embedding Extraction"):
            img_tensor = img_tensor.to(device)

            # Get embeddings
            img_emb, txt_emb = model(images=img_tensor, texts=captions)

            # Accumulate on CPU in float32 so pooling and matmul behave the same
            # regardless of the precision the model ran in.
            img_emb_list.append(img_emb.cpu().float())
            txt_emb_list.append(txt_emb.cpu().float())
            img_id_list.extend(_to_hashable_ids(img_ids))

    if not img_emb_list:
        raise ValueError("dataloader yielded no batches; nothing to evaluate")

    # Concatenate embeddings (one row per image-caption pair)
    img_emb = torch.cat(img_emb_list, dim=0)
    txt_emb = torch.cat(txt_emb_list, dim=0)
    if not (len(img_id_list) == img_emb.shape[0] == txt_emb.shape[0]):
        raise ValueError(
            f"mismatched sample counts: {len(img_id_list)} image ids, "
            f"{img_emb.shape[0]} image embeddings, {txt_emb.shape[0]} text embeddings"
        )

    # Map every sample to the index of its image ID (first-seen order).
    img_id_to_idx = {}
    group_idx = torch.tensor(
        [img_id_to_idx.setdefault(img_id, len(img_id_to_idx)) for img_id in img_id_list],
        dtype=torch.long,
    )
    num_unique = len(img_id_to_idx)

    # Pool per image ID: average captions, and collapse repeated image rows so
    # both axes of the similarity matrix index the same set of unique images.
    unique_img_emb = _group_mean(img_emb, group_idx, num_unique)
    avg_txt_emb = _group_mean(txt_emb, group_idx, num_unique)

    # Similarity matrix (image x text), square with ground truth on the diagonal
    sim_mat = torch.matmul(unique_img_emb, avg_txt_emb.t())

    # Image→Text retrieval
    img2txt_ranks = _diagonal_ranks(sim_mat)
    # Text→Image retrieval (transpose similarity matrix)
    txt2img_ranks = _diagonal_ranks(sim_mat.t())

    # Compute metrics (§4.2.2: R@1/R@5/R@10 + MRR)
    metrics = {
        "image_to_text": _retrieval_metrics(img2txt_ranks),
        "text_to_image": _retrieval_metrics(txt2img_ranks)
    }

    # Log and save
    logger.info("Cross-Modal Retrieval Results (§4.2.2):")
    logger.info("Image→Text:")
    for k, v in metrics["image_to_text"].items():
        logger.info(f"  {k}: {v:.4f}")
    logger.info("Text→Image:")
    for k, v in metrics["text_to_image"].items():
        logger.info(f"  {k}: {v:.4f}")

    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)
        with open(os.path.join(save_path, "cross_modal_metrics.json"), "w") as f:
            json.dump(metrics, f, indent=4)
    return metrics


# ------------------------------
# 3. Fine-Grained Alignment (§4.2.3: ADE20K Mask Transfer)
# ------------------------------
def fine_grained_alignment(model, dataloader, device, save_path=None):
    """Implements zero-shot text-driven mask transfer (§4.2.3)"""
    model.eval()
    all_miou = []
    all_obj_alignment = []

    logger.info("Evaluating fine-grained alignment on ADE20K...")
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Fine-Grained Eval"):
            img_tensor = batch['image'].to(device)
            mask_tensor = batch['mask'].to(device)  # [B, 1, H, W]
            caption = batch['caption']
            object_texts = batch['object_texts']
            B = img_tensor.shape[0]

            # Step 1: Extract image patch embeddings
            # Get pixel-level refined features (modify model to return intermediate features)
            _, _, _, pixel_feat = model.forward_pixel_alignment(img_tensor, caption)  # [B, H*W, d]
            H, W = img_tensor.shape[-2], img_tensor.shape[-1]
            patch_feat = pixel_feat.reshape(B, H, W, -1).permute(0, 3, 1, 2)  # [B, d, H, W]

            # Step 2: Extract object text embeddings
            obj_emb_list = []
            for obj_text in object_texts:
                if not obj_text:
                    obj_emb_list.append(torch.zeros(1, model.d_model, device=device))