
import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode



def _cosine_conf(sim: float) -> float:
    """Rescale a cosine similarity from [-1, 1] to a confidence in [0, 1].

    The mapping is affine: -1 -> 0.0, 0 -> 0.5, 1 -> 1.0.
    """
    shifted = 1.0 + sim
    return shifted / 2.0

def _pair_sim(model, image_batch_1x3HW: torch.Tensor, text_feat_1xD: torch.Tensor) -> float:
    """Cosine similarity between the model's image embedding and one text feature.

    The image is embedded with ``model.encode_image``; both the embedding and
    the text feature are L2-normalised along the last dimension before the dot
    product. Only element ``[0, 0]`` of the resulting matrix is returned, as a
    Python float. Gradients are not tracked.
    """
    with torch.no_grad():
        img_emb = model.encode_image(image_batch_1x3HW)
        img_unit = img_emb / img_emb.norm(dim=-1, keepdim=True)
        txt_unit = text_feat_1xD / text_feat_1xD.norm(dim=-1, keepdim=True)
        sim_matrix = img_unit @ txt_unit.T
    return sim_matrix[0, 0].item()

def _to_numpy_img(image_tensor: torch.Tensor) -> np.ndarray:
    """Convert a [1,3,H,W] or [3,H,W] tensor to an [H,W,3] numpy image in [0,1].

    The values are min-max rescaled over the whole image, so the result lands
    in [0, 1] regardless of the input's value range (e.g. if the tensor was
    already normalised with the CLIP mean/std). A constant image maps to all
    zeros because the range is floored at 1e-8.
    """
    chw = image_tensor.squeeze(0) if image_tensor.dim() == 4 else image_tensor
    hwc = chw.detach().cpu().permute(1, 2, 0).numpy()

    lo = hwc.min()
    # Floor the range so a constant image does not divide by zero.
    span = max(1e-8, (hwc.max() - lo))
    return (hwc - lo) / span
def _upsample_map(sim_map_2d: np.ndarray, H: int, W: int) -> np.ndarray:
    """Resize a heatmap to ``H`` x ``W`` with bilinear interpolation.

    The map is cast to float32 first so that ``cv2.resize`` accepts it.

    Raises:
        ValueError: if ``sim_map_2d`` is None.
    """
    if sim_map_2d is None:
        raise ValueError("sim_map_2d is None before resizing")

    as_float32 = sim_map_2d.astype(np.float32)
    target_size = (W, H)  # cv2 takes (width, height), not (rows, cols)
    return cv2.resize(as_float32, target_size, interpolation=cv2.INTER_LINEAR)


def _topk_binary_mask(sim_map_2d: np.ndarray, topk_ratio: float) -> np.ndarray:
    """Return a binary mask selecting the top-k fraction of pixels of the map.

    Args:
        sim_map_2d: heatmap; the mask has the same shape.
        topk_ratio: fraction of pixels to select. The resulting pixel count is
            clamped to ``[1, sim_map_2d.size]``, so ratios <= 0 select the
            single highest pixel and ratios >= 1 select every pixel.

    Returns:
        uint8 array with 1 where the map value is >= the k-th largest value
        and 0 elsewhere. Because the comparison is ``>=``, ties at the
        threshold may select more than k pixels.

    Raises:
        ValueError: if the map has no elements.
    """
    flat = sim_map_2d.reshape(-1)
    if flat.size == 0:
        raise ValueError("sim_map_2d is empty; cannot select top-k pixels")

    # Clamp k: np.partition rejects a kth index outside the array, which would
    # otherwise surface as an opaque "kth out of bounds" error for ratio > 1.
    k = min(flat.size, max(1, int(topk_ratio * flat.size)))
    thr = np.partition(flat, -k)[-k]
    return (sim_map_2d >= thr).astype(np.uint8)

def _apply_mask(image_tensor: torch.Tensor,
                mask_2d_uint8: np.ndarray,
                mode: str = "zero",
                blur_kernel: int = 15) -> torch.Tensor:
    """
    Remove the masked region of an image while leaving all other pixels untouched.

    image_tensor: [1,3,H,W]
    mask_2d_uint8: [H,W] 1=mask (to be removed), 0=keep
    mode: "blur" replaces the masked region with a Gaussian-blurred copy of
          the image; any other value behaves like "zero", which replaces it
          with the image's minimum value (the zero of the min-max scale the
          masking is computed in).
    blur_kernel: Gaussian kernel size for "blur"; forced to a positive odd
          number because cv2.GaussianBlur rejects even sizes.

    Returns a new [1,3,H,W] tensor on the same device and with the same dtype
    as ``image_tensor``, expressed in the SAME value range as the input, so
    the result can be fed to the model exactly like the unmasked image. The
    input tensor is not modified.
    """
    assert image_tensor.dim() == 4 and image_tensor.size(0) == 1
    _, _, H, W = image_tensor.shape
    assert mask_2d_uint8.shape == (H, W)

    # [H,W,3] float32 view of the input in its original value range.
    img_hwc = image_tensor.detach()[0].permute(1, 2, 0).cpu().float().numpy()
    lo = float(img_hwc.min())
    span = max(1e-8, float(img_hwc.max()) - lo)

    mask3 = np.repeat(mask_2d_uint8[:, :, None], 3, axis=2).astype(np.float32)

    if mode == "blur":
        # cv2 blurs an 8-bit image, so go through the [0,1] min-max scale and
        # map the blurred result back to the input range afterwards.
        img01 = (img_hwc - lo) / span
        ksize = max(1, int(blur_kernel)) | 1
        blurred01 = cv2.GaussianBlur((img01 * 255).astype(np.uint8), (ksize, ksize), 0).astype(np.float32) / 255.0
        fill = blurred01 * span + lo
    else:  # "zero"
        fill = np.full_like(img_hwc, lo)

    # Blend in the original value range: where mask == 0 the pixel is exactly
    # the input pixel (no min-max rescaling leaks into the kept region).
    mixed = (img_hwc * (1.0 - mask3) + fill * mask3).astype(np.float32)

    mixed_t = torch.from_numpy(mixed).permute(2, 0, 1).unsqueeze(0).to(image_tensor.device).type_as(image_tensor)
    return mixed_t



def visualize_explanation_attack(
    image_tensor: torch.Tensor,
    sim_map_2d: np.ndarray,
    sim_orig: float,
    sim_masked: float,
    image_tensor_adv: torch.Tensor,
    sim_map_2d_adv: np.ndarray,
    sim_orig_adv: float,
    sim_masked_adv: float,
    title: str = "",
    title_adv: str = "",
    topk_ratio: float = 0.10,
    drop_conf_threshold: float = 0.03,
    show: bool = True):
    """
    Score, and optionally plot, the faithfulness of a clean and an adversarial heatmap.

    A heatmap is flagged "misleading" when masking its top-k region lowers the
    confidence (cosine similarity mapped to [0,1]) by less than
    ``drop_conf_threshold``.

    When ``show`` is True, 4 plots are drawn in one row:
    clean heatmap overlay, clean image (top-k pixels painted red only if
    misleading), adversarial heatmap overlay, adversarial image (same rule).
    The panel titles are prefixed with "[CLEAN] " / "[ADVERSARIAL] ", so
    ``title`` / ``title_adv`` should not repeat those tags.

    ``image_tensor_adv`` may be None, in which case the adversarial half is
    skipped and its result is None.

    Returns:
        {"clean": metrics, "adv": metrics or None}, where each metrics dict
        has the keys sim_orig, sim_masked, conf_orig, conf_masked, drop_sim,
        drop_conf, topk_mask (uint8 [H,W]) and is_misleading (0.0 / 1.0).
        The metrics are computed whether or not ``show`` is set.
    """

    def _score(img_t, sim_map, s_orig, s_masked):
        # Pure computation, kept separate from plotting so show=False works.
        img_np = _to_numpy_img(img_t)
        H, W = img_np.shape[:2]
        sim_map_res = _upsample_map(sim_map, H, W)

        conf_o = _cosine_conf(s_orig)
        conf_m = _cosine_conf(s_masked)
        drop_conf = conf_o - conf_m

        metrics = {
            "sim_orig": s_orig,
            "sim_masked": s_masked,
            "conf_orig": conf_o,
            "conf_masked": conf_m,
            "drop_sim": s_orig - s_masked,
            "drop_conf": drop_conf,
            "topk_mask": _topk_binary_mask(sim_map_res, topk_ratio=topk_ratio),
            "is_misleading": float(drop_conf < drop_conf_threshold)
        }
        return img_np, sim_map_res, metrics

    def _draw(img_np, sim_map_res, metrics, panel_title, ax_heatmap, ax_topk):
        span = max(1e-8, (sim_map_res.max() - sim_map_res.min()))
        sim_map_norm = (sim_map_res - sim_map_res.min()) / span

        # --- heatmap overlay ---
        # applyColorMap returns BGR while img_np is RGB: convert the colormap
        # BEFORE blending, otherwise the photo's R and B channels get swapped.
        hm = (sim_map_norm * 255).astype(np.uint8)
        hm_rgb = cv2.cvtColor(cv2.applyColorMap(hm, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        overlay = (img_np * 255 * 0.4 + hm_rgb * 0.6).astype(np.uint8)

        # --- highlight only if misleading ---
        if metrics["is_misleading"]:
            highlight = img_np.copy()
            highlight[metrics["topk_mask"] == 1] = [1.0, 0.0, 0.0]
            ax_topk.imshow(highlight)
            ax_topk.set_title("Misleading patches")
        else:  # faithful
            ax_topk.imshow(img_np)
            ax_topk.set_title("Faithful (no highlight)")

        ax_heatmap.imshow(overlay)
        ax_heatmap.set_title(f"{panel_title}\nHeatmap")
        ax_heatmap.axis('off')
        ax_topk.axis('off')

    clean = _score(image_tensor, sim_map_2d, sim_orig, sim_masked)
    adv = None
    if image_tensor_adv is not None:
        adv = _score(image_tensor_adv, sim_map_2d_adv, sim_orig_adv, sim_masked_adv)

    if show:
        _, axs = plt.subplots(1, 4, figsize=(20, 5))
        _draw(*clean, f"[CLEAN] {title}", axs[0], axs[1])
        if adv is not None:
            _draw(*adv, f"[ADVERSARIAL] {title_adv}", axs[2], axs[3])
        else:
            # No adversarial input: hide the two unused panels.
            for ax in axs[2:]:
                ax.axis('off')

        plt.tight_layout()
        plt.show()

    return {"clean": clean[2], "adv": adv[2] if adv is not None else None}


def detect_explanation_attack_for_labels(
    model,
    image_clean_1x3HW: torch.Tensor,
    image_adv_1x3HW: torch.Tensor,
    sim_map_clean_4d: torch.Tensor,  # [1,Hm,Wm,N]
    sim_map_adv_4d: torch.Tensor,    # [1,Hm,Wm,N]
    text_features_2d: torch.Tensor,  # [N, D]
    labels: List[str],
    target_labels: List[str],
    topk_ratio: float = 0.10,
    mask_mode: str = "blur",         # "blur" or "zero"
    drop_conf_threshold: float = 0.005,
    visualize: bool = True
) -> Dict[str, Dict[str, float]]:
    """
    For each target label:
      - mask top-k heatmap region on clean & adv
      - compute sim & confidence drops
      - visualize (optional)
      - return a summary dict

    Index ``n`` along the last axis of the similarity maps is paired with
    ``labels[n]`` and ``text_features_2d[n]``; labels not listed in
    ``target_labels`` are skipped. A heatmap is flagged misleading when
    masking its top-k region drops the confidence by less than
    ``drop_conf_threshold``; the same threshold is used for the returned
    flags and for the visualisation.

    Returns:
        {label: metrics} with the keys sim_clean_orig, sim_clean_masked,
        drop_clean_sim, drop_clean_conf, sim_adv_orig, sim_adv_masked,
        drop_adv_sim, drop_adv_conf, clean_misleading_flag,
        adv_misleading_flag (0.0 / 1.0) and threshold_conf. A one-line
        summary per label is also printed.
    """
    device = image_clean_1x3HW.device
    summary = {}

    # shapes: N label channels in the maps, (H, W) of the input image
    N = sim_map_clean_4d.shape[3]
    H, W = image_clean_1x3HW.shape[2], image_clean_1x3HW.shape[3]
    topk_pct = int(topk_ratio * 100)

    for n in range(N):
        label = labels[n]
        if label not in target_labels:
            continue

        # per-label low-resolution maps (Hm, Wm); upsampled to (H, W) below
        sim_map_clean_np = sim_map_clean_4d[0, :, :, n].detach().cpu().numpy()
        sim_map_adv_np = sim_map_adv_4d[0, :, :, n].detach().cpu().numpy()

        # ensure correct text feature shape [1, D]
        text_feat = text_features_2d[n].unsqueeze(0).to(device)

        # --- CLEAN ---
        sim_clean_orig = _pair_sim(model, image_clean_1x3HW, text_feat)
        clean_mask = _topk_binary_mask(_upsample_map(sim_map_clean_np, H, W), topk_ratio)
        img_clean_masked = _apply_mask(image_clean_1x3HW, clean_mask, mode=mask_mode)
        sim_clean_masked = _pair_sim(model, img_clean_masked, text_feat)

        # --- ADVERSARIAL ---
        sim_adv_orig = _pair_sim(model, image_adv_1x3HW, text_feat)
        adv_mask = _topk_binary_mask(_upsample_map(sim_map_adv_np, H, W), topk_ratio)
        img_adv_masked = _apply_mask(image_adv_1x3HW, adv_mask, mode=mask_mode)
        sim_adv_masked = _pair_sim(model, img_adv_masked, text_feat)

        if visualize:
            # The visualiser prepends "[CLEAN]" / "[ADVERSARIAL]" itself, so
            # the tags are not repeated in the titles passed here.
            panel_title = f"{label} — topk={topk_pct}%, mode={mask_mode}"
            _ = visualize_explanation_attack(
                image_tensor=image_clean_1x3HW,
                sim_map_2d=sim_map_clean_np,
                sim_orig=sim_clean_orig,
                sim_masked=sim_clean_masked,
                title=panel_title,

                image_tensor_adv=image_adv_1x3HW,
                sim_map_2d_adv=sim_map_adv_np,
                sim_orig_adv=sim_adv_orig,
                sim_masked_adv=sim_adv_masked,
                title_adv=panel_title,

                topk_ratio=topk_ratio,
                # Use the caller's threshold so the plot agrees with the flags.
                drop_conf_threshold=drop_conf_threshold,
                show=True
            )

        # metrics
        conf_clean_orig = _cosine_conf(sim_clean_orig)
        conf_clean_masked = _cosine_conf(sim_clean_masked)
        conf_adv_orig = _cosine_conf(sim_adv_orig)
        conf_adv_masked = _cosine_conf(sim_adv_masked)

        drop_clean_conf = conf_clean_orig - conf_clean_masked
        drop_adv_conf = conf_adv_orig - conf_adv_masked
        drop_clean_sim = sim_clean_orig - sim_clean_masked
        drop_adv_sim = sim_adv_orig - sim_adv_masked

        # flag misleading if masking the top-k region barely changes the confidence
        clean_flag = (drop_clean_conf < drop_conf_threshold)
        adv_flag = (drop_adv_conf < drop_conf_threshold)

        summary[label] = {
            "sim_clean_orig": sim_clean_orig,
            "sim_clean_masked": sim_clean_masked,
            "drop_clean_sim": drop_clean_sim,
            "drop_clean_conf": drop_clean_conf,
            "sim_adv_orig": sim_adv_orig,
            "sim_adv_masked": sim_adv_masked,
            "drop_adv_sim": drop_adv_sim,
            "drop_adv_conf": drop_adv_conf,
            "clean_misleading_flag": float(clean_flag),
            "adv_misleading_flag": float(adv_flag),
            "threshold_conf": drop_conf_threshold
        }

        print(
            f"[{label}]  "
            f"CLEAN drop_conf={drop_clean_conf:.4f}  ADV drop_conf={drop_adv_conf:.4f} "
            f"-> flags: clean={clean_flag} adv={adv_flag}"
        )

    return summary