import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode


# ---------------------------
# Utility functions
# ---------------------------
def cosine_sim(a, b):
    """Return the cosine similarity of ``a`` and ``b`` (along dim 1) as a Python float.

    ``.item()`` requires the similarity result to contain exactly one element,
    i.e. a single pair of rows.
    """
    similarity = F.cosine_similarity(a, b)
    return similarity.item()

def max_delta_prob(logits1, logits2):
    """Return the largest absolute per-class gap between the softmaxes of two logit tensors.

    Softmax is taken over the last dimension; the maximum is over all elements.
    """
    gap = torch.softmax(logits1, dim=-1) - torch.softmax(logits2, dim=-1)
    return float(gap.abs().amax())

def mask_iou(mask1, mask2):
    """Intersection-over-union of two masks, treating nonzero entries as foreground.

    The union is clamped to at least 1, so two empty masks give 0 rather than NaN.
    """
    overlap = np.logical_and(mask1, mask2)
    combined = np.logical_or(mask1, mask2)
    union_count = combined.sum()
    return overlap.sum() / (union_count if union_count > 0 else 1)

def mask_dice(mask1, mask2):
    """Dice coefficient 2|A∩B| / (|A| + |B|) of two masks, nonzero entries = foreground.

    Foreground sizes are counted with ``np.count_nonzero``, the same way the
    intersection is formed with ``np.logical_and``. Masks stored with values
    other than 0/1 (e.g. 0/255) therefore give the same score as their binary
    equivalents. This matches how ``mask_iou`` interprets masks.
    The small epsilon makes two empty masks score 0.0 instead of dividing by zero.
    """
    inter = np.count_nonzero(np.logical_and(mask1, mask2))
    size_sum = np.count_nonzero(mask1) + np.count_nonzero(mask2)
    return 2 * inter / (size_sum + 1e-8)

def _topk_binary_mask(sim_map_2d: np.ndarray, topk_ratio: float) -> np.ndarray:
    flat = sim_map_2d.reshape(-1)
    k = max(1, int(topk_ratio * flat.size))
    thr = np.partition(flat, -k)[-k]
    return (sim_map_2d >= thr).astype(np.uint8)

def evaluate_predictions(z_clean, z_adv, text_features, all_texts):
    """Return zero-shot predicted labels plus drift metrics for a clean/adversarial pair.

    Args:
        z_clean, z_adv: image-encoder outputs indexed as ``[:, 0, :]``. Token 0
            is used as the CLS embedding. The ``.item()`` calls assume batch size 1.
        text_features: text embeddings indexed as (K, D). Logits are plain dot
            products with the L2-normalized CLS vectors.
        all_texts: sequence of K label strings, indexed by the argmax class.

    Returns:
        dict with ``clean_pred`` / ``adv_pred`` (label strings), ``cosine_sim``
        (between the two normalized CLS vectors) and ``max_delta_prob`` (largest
        absolute per-class change of the softmax over classes).
    """
    eps = 1e-10  # guards against division by a zero norm
    cls_clean = z_clean[:, 0, :]
    cls_adv = z_adv[:, 0, :]
    cls_clean = cls_clean / (cls_clean.norm(dim=-1, keepdim=True) + eps)
    cls_adv = cls_adv / (cls_adv.norm(dim=-1, keepdim=True) + eps)

    text_t = text_features.T
    scores_clean = cls_clean @ text_t
    scores_adv = cls_adv @ text_t

    prob_gap = torch.softmax(scores_adv, dim=1) - torch.softmax(scores_clean, dim=1)

    return {
        "clean_pred": all_texts[scores_clean.argmax(1).item()],
        "adv_pred": all_texts[scores_adv.argmax(1).item()],
        "cosine_sim": F.cosine_similarity(cls_clean, cls_adv).item(),
        "max_delta_prob": prob_gap.abs().max().item(),
    }
def _encode_cls_normalized(model, image):
    """Return the L2-normalized token-0 embedding of ``image`` from ``model.encode_image``.

    Assumes ``encode_image`` returns a (B, T, D) token sequence whose first token
    serves as the CLS summary (TODO confirm against the model wrapper in use).
    """
    z = model.encode_image(image)[:, 0, :]
    return z / z.norm(dim=-1, keepdim=True)


def _upsample_mask(mask, height, width):
    """Resize a low-resolution binary mask to (height, width) with nearest-neighbour sampling."""
    # cv2.resize takes dsize as (width, height); INTER_NEAREST keeps the mask binary.
    return cv2.resize(mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST)


def _zero_masked_pixels(img, mask_hw):
    """Return ``img`` (B, C, H, W) with every pixel where ``mask_hw`` (H, W) is nonzero set to 0.

    Masking happens in the image's own value space (no re-scaling). The masked
    input therefore differs from the unmasked reference only in the masked
    region. If inputs are mean/std-normalized, a 0 here corresponds to the
    normalization mean colour.
    """
    keep = torch.from_numpy((mask_hw == 0).astype(np.float32))
    keep = keep.to(device=img.device, dtype=img.dtype)
    return img * keep  # (H, W) broadcasts over the batch and channel dims


def evaluate_image_pair(model, image_clean, image_adv, text_features, label_idx,
                        sim_map_clean, sim_map_adv, topk_ratio=0.3, drop_conf_threshold=0.03):
    """
    Compute all metrics for a single (clean, adv) pair and one target label.

    Args:
        model: object exposing ``encode_image``; token 0 of its output is used as
            the image embedding.
        image_clean, image_adv: image tensors indexed as (B, C, H, W). The
            ``.item()`` calls assume B == 1.
        text_features: (K, D) text embeddings (presumably L2-normalized, e.g. by
            ``batched_encode_text`` — TODO confirm). Moved to the image device here.
        label_idx: target label index into ``text_features`` and into the last axis
            of the similarity maps.
        sim_map_clean, sim_map_adv: similarity maps indexed as ``[0, h, w, label]``.
        topk_ratio: fraction of map cells kept by the top-k binary masks.
        drop_conf_threshold: a similarity drop below this flags the explanation
            as misleading (masking its top region barely mattered).

    Returns:
        dict of floats: ``CosSim_CLS``, ``MaxDeltaProb``, ``ConfDrop_Clean``,
        ``ConfDrop_Adv``, ``MisleadingRate_Clean``, ``MisleadingRate_Adv``,
        ``IoU_Topk``.
    """
    device = image_clean.device
    text_features = text_features.to(device)
    text_feat = text_features[label_idx].unsqueeze(0)

    # Pure evaluation: every metric is reduced to a Python float, so skip autograd.
    with torch.no_grad():
        # CLS embeddings
        z_clean = _encode_cls_normalized(model, image_clean)
        z_adv = _encode_cls_normalized(model, image_adv)

        cos_cls = cosine_sim(z_clean, z_adv)
        delta_prob = max_delta_prob(z_clean @ text_features.T, z_adv @ text_features.T)

        # Top-k masks over the target label's similarity map, and their overlap
        sim_clean_np = sim_map_clean[0, :, :, label_idx].detach().cpu().numpy()
        sim_adv_np = sim_map_adv[0, :, :, label_idx].detach().cpu().numpy()
        mask_clean = _topk_binary_mask(sim_clean_np, topk_ratio)
        mask_adv = _topk_binary_mask(sim_adv_np, topk_ratio)
        iou = mask_iou(mask_clean, mask_adv)

        # Faithfulness: similarity to the target label before vs. after masking
        # the top-k region at full image resolution.
        sim_clean_orig = (z_clean @ text_feat.T).item()
        sim_adv_orig = (z_adv @ text_feat.T).item()

        img_clean_masked = _zero_masked_pixels(
            image_clean, _upsample_mask(mask_clean, *image_clean.shape[2:]))
        img_adv_masked = _zero_masked_pixels(
            image_adv, _upsample_mask(mask_adv, *image_adv.shape[2:]))

        # Normalize exactly like z_clean / z_adv so both sides of the drop are
        # on the same scale (previously raw, un-normalized features were used).
        sim_clean_masked = (_encode_cls_normalized(model, img_clean_masked) @ text_feat.T).item()
        sim_adv_masked = (_encode_cls_normalized(model, img_adv_masked) @ text_feat.T).item()

    # Positive drop: removing the highlighted region lowered the label similarity.
    drop_clean_conf = sim_clean_orig - sim_clean_masked
    drop_adv_conf = sim_adv_orig - sim_adv_masked

    misleading_clean = float(drop_clean_conf < drop_conf_threshold)
    misleading_adv = float(drop_adv_conf < drop_conf_threshold)

    return {
        "CosSim_CLS": cos_cls,
        "MaxDeltaProb": delta_prob,
        "ConfDrop_Clean": drop_clean_conf,
        "ConfDrop_Adv": drop_adv_conf,
        "MisleadingRate_Clean": misleading_clean,
        "MisleadingRate_Adv": misleading_adv,
        "IoU_Topk": iou
    }

def batched_encode_text(model, texts, device, batch_size=64):
    """Encode ``texts`` with ``model.encode_text`` in chunks of ``batch_size``.

    Each chunk is tokenized with ``clip.tokenize`` and moved to ``device``. The
    resulting rows are L2-normalized, and the chunks are concatenated along dim 0
    in input order. Runs under ``torch.no_grad``. An empty ``texts`` makes
    ``torch.cat`` fail because there is nothing to concatenate.
    """
    def _encode_chunk(chunk):
        emb = model.encode_text(clip.tokenize(chunk).to(device))
        return emb / emb.norm(dim=-1, keepdim=True)

    with torch.no_grad():
        encoded = [
            _encode_chunk(texts[lo:lo + batch_size])
            for lo in range(0, len(texts), batch_size)
        ]
    return torch.cat(encoded, dim=0)
 