
import os
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image

import clip
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

# Your modules
from Utils import *
from Adv_attack import *
from Evaluation import *
from Plots import *

# ----------------------------
# Config
# ----------------------------

# Experiment-level settings.
# NOTE(review): in the code visible in this file only "dataset",
# "target_text_clean" and "target_text_adv" are read; "base_model",
# "surgery_model" and "use_custom_preprocess" are not referenced here
# (the surgery model actually loaded by main() comes from SURGERY_MODEL).
CONFIG = {
    # Must be a key of DATASET_ROOTS / DATASET_DEFAULT_SAMPLE:
    # "Flickr30K" (dog), "COCO" (cat), "ImageNet" (bird), "Demo" (bench).
    "dataset": "Flickr30K",
    "base_model": "ViT-B/16",
    "surgery_model": "CS-ViT-B/16",
    # List of clean labels; the first entry is the label whose similarity map is evaluated.
    "target_text_clean": ["dog"],
    # Label whose index in ALL_TEXTS is handed to the attack as `target_idx`.
    "target_text_adv": "ground",
    "use_custom_preprocess": False
}


# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Interpolation mode used by the Resize step of custom_preprocess.
BICUBIC = InterpolationMode.BICUBIC

# Every model in this list is used both as an attack source and as an evaluation target in main().
MODELS = ["ViT-L/14", "ViT-B/16", "ViT-B/32"]  # source/target set
# Name passed to clip.load() in main() when truthy; None skips loading a surgery model.
SURGERY_MODEL = None  # e.g. "CS-ViT-B/16" if you want surgery maps, else None


# Directory holding the images of each dataset.
# NOTE: these are absolute paths on one specific machine (drive F:); adjust them before running elsewhere.
DATASET_ROOTS = {
    "Flickr30K": "F:/All codes/CLIP_Surgery/flick30Images/flickr30k_images/",
    "COCO":      "F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/",
    "ImageNet":  "F:/All codes/CLIP_Surgery/image-net-val/",
    "Demo":      "F:/All codes/CLIP_Surgery/codes_paper/"
}

# File name (relative to the dataset root) of the single sample image used for each dataset.
DATASET_DEFAULT_SAMPLE = {
    "Flickr30K": "459814265.jpg",
    "COCO":      "000000003599.jpg",
    "ImageNet":  "ILSVRC2012_val_00002023.JPEG",
    "Demo":      "demo.jpg"
}



# Full path of the sample image for the dataset selected in CONFIG["dataset"].
# A dataset name missing from either dict raises KeyError here.
IMAGE_PATH = os.path.join(
    DATASET_ROOTS[CONFIG["dataset"]],
    DATASET_DEFAULT_SAMPLE[CONFIG["dataset"]]
)

print("Using IMAGE_PATH:", IMAGE_PATH)

# Fail at import time, before any model is loaded, when the sample image is missing.
if not os.path.isfile(IMAGE_PATH):
    raise FileNotFoundError(f"Dataset sample not found: {IMAGE_PATH}")
    
    
# Label vocabulary; every label is encoded as a text feature for each model in main().
# dict.fromkeys() drops duplicate labels while keeping first-occurrence order,
# so positions in this list are stable class indices.
ALL_TEXTS = list(dict.fromkeys([
    'airplane','bag','bed','bedclothes','bench','bicycle','bird','boat','book','bottle',
    'building','bus','cabinet','car','cat','ceiling','chair','cloth','computer','cow','cup','curtain',
    'dog','door','fence','floor','flower','food','grass','ground','horse','keyboard','light','motorbike',
    'mountain','mouse','person','plate','platform','potted plant','road','rock','sheep','shelves','sidewalk',
    'sign','sky','snow','sofa','table','track','train','tree','truck','tv monitor','wall','water','window',
    'wood','stair','desk','cards','vespa','bear','banana','piano'
]))

# Clean explanation target (for IoU evaluation).
# ALL_TEXTS.index() raises ValueError when the configured label is not in the vocabulary.
# NOTE: main() derives these three values again as local variables.
target_texts = CONFIG["target_text_clean"]          # list of clean labels
TARGET_TEXT = target_texts[0]                       # single label for XAI plots
target_idx = ALL_TEXTS.index(TARGET_TEXT)           # integer index

# Evaluation knobs
# Evaluation knobs
PERCENT_K = 0.25          # fraction of patches kept by iou_topk_percent (passed as `frac` in main)
SOFT_TEMP = 1.0           # softmax temperature passed to soft_iou in main
STEALTH_THRESHOLD = 0.98  # cosine threshold; only referenced by the commented-out scatter plot in main

# Attack hyperparameters. main() copies this dict and unpacks it as keyword
# arguments into generate_adversarial_image(), so every key must match a
# parameter of that function.
# NOTE(review): the meaning of the individual values is defined by
# generate_adversarial_image (imported via `from Adv_attack import *`), not in
# this file — confirm there before tuning.
ATTACK_CFG = dict(
    steps=400,
    step_size=0.05,
    lambda_pred=0.5,
    lambda_entropy=0.5,
    lambda_margin=0.5,
    K_patches=100,
    l0_ratio=0.05,
    use_margin=True,
    use_surgery=False,           # set True only if SURGERY_MODEL is provided and coded
    model_surgery=None,          # main() fills this in when use_surgery is True and a surgery model was loaded
    surgery_temp=2.0,
    log_every=50,
)

# ----------------------------
# Preprocess
# ----------------------------
# Image pipeline used by load_image(): resize straight to 224x224 with bicubic
# interpolation (no centre crop, so the aspect ratio is not preserved), convert
# to a float tensor, then normalise each channel with the mean/std below.
# NOTE(review): the constants look like the standard CLIP normalisation values — confirm.
# NOTE: load_image() applies this pipeline unconditionally; CONFIG["use_custom_preprocess"]
# is not consulted in the visible code.
custom_preprocess = Compose([
    Resize((224, 224), interpolation=BICUBIC),
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073),
              (0.26862954, 0.26130258, 0.27577711))
])

def load_image(path: str) -> Tuple[torch.Tensor, np.ndarray]:
    """Load an image file as a preprocessed model input plus a BGR pixel array.

    Args:
        path: Path of the image file to read.

    Returns:
        A tuple ``(image, cv2_img)``:
        ``image`` is the output of ``custom_preprocess`` with a leading batch
        dimension added, moved to ``DEVICE``;
        ``cv2_img`` is the RGB-converted image at its original resolution with
        the channel axis reversed (BGR), stored contiguously.

    Raises:
        FileNotFoundError / OSError: propagated from ``PIL.Image.open`` when the
            file is missing or cannot be decoded.
    """
    # Image.open() keeps the underlying file open until the image is closed;
    # convert() returns an independent in-memory copy, so the handle can be
    # released as soon as the RGB copy exists.
    with Image.open(path) as raw_img:
        pil_img = raw_img.convert("RGB")

    image = custom_preprocess(pil_img).unsqueeze(0).to(DEVICE)

    # Reversing the channel axis yields a negative-stride view, which some
    # consumers (e.g. torch.from_numpy, in-place drawing routines) reject.
    # Materialise it as a contiguous array with identical values.
    cv2_img = np.ascontiguousarray(np.array(pil_img)[:, :, ::-1])  # BGR for the plot utils
    return image, cv2_img

# ----------------------------
# Metrics (percent-k IoU, soft IoU, Spearman, Wasserstein)
# ----------------------------
def iou_topk_percent(sim_map_clean, sim_map_adv, class_idx: int, frac: float = 0.25) -> float:
    """Return the IoU of the top-scoring patches of two similarity maps.

    For the class ``class_idx`` the ``k`` highest-scoring patches are selected
    in each map and the intersection-over-union of the two selections is
    returned.

    Args:
        sim_map_clean: Tensor indexed as ``[batch, row, col, class]``; only
            batch element 0 is used.
        sim_map_adv: Tensor with the same layout and spatial size.
        class_idx: Index of the class channel to compare.
        frac: Fraction of the patches to keep; ``k = ceil(frac * n_patches)``,
            clamped to ``[1, n_patches]``.

    Returns:
        Intersection-over-union in ``[0, 1]``; ``0.0`` when the union is empty.
    """
    s1 = sim_map_clean[0, :, :, class_idx].reshape(-1)
    s2 = sim_map_adv[0, :, :, class_idx].reshape(-1)

    n_patches = sim_map_clean.shape[1] * sim_map_clean.shape[2]
    # torch.topk raises when k exceeds the number of elements, so clamp k for
    # frac > 1 (and for an empty map, where k becomes 0).
    k = min(n_patches, max(1, int(np.ceil(frac * n_patches))))

    top1 = torch.topk(s1, k).indices
    top2 = torch.topk(s2, k).indices

    m1 = torch.zeros_like(s1, dtype=torch.bool)
    m1[top1] = True
    m2 = torch.zeros_like(s2, dtype=torch.bool)
    m2[top2] = True

    inter = (m1 & m2).sum().item()
    uni = (m1 | m2).sum().item()
    return inter / uni if uni else 0.0

def soft_iou(sim_map_clean, sim_map_adv, class_idx: int, temp: float = 1.0) -> float:
    """Return a soft IoU between two similarity maps for one class.

    Each map's scores for ``class_idx`` (batch element 0) are flattened and
    turned into a distribution over patches with a temperature-scaled softmax.
    The result is the sum of element-wise minima divided by the sum of
    element-wise maxima of the two distributions.

    Args:
        sim_map_clean: Tensor indexed as ``[batch, row, col, class]``.
        sim_map_adv: Tensor with the same layout.
        class_idx: Index of the class channel to compare.
        temp: Softmax temperature; scores are divided by it.

    Returns:
        The ratio described above, or ``0.0`` when the denominator is zero.
    """
    def _patch_distribution(sim_map):
        scores = sim_map[0, :, :, class_idx].flatten()
        return torch.softmax(scores / temp, dim=0)

    dist_clean = _patch_distribution(sim_map_clean)
    dist_adv = _patch_distribution(sim_map_adv)

    overlap = torch.minimum(dist_clean, dist_adv).sum().item()
    envelope = torch.maximum(dist_clean, dist_adv).sum().item()

    if not envelope:
        return 0.0
    return overlap / envelope

def spearman_and_emd(sim_map_clean, sim_map_adv, class_idx: int) -> Tuple[float, float]:
    """Return Spearman rank correlation and Wasserstein distance of two maps.

    The scores of class ``class_idx`` (batch element 0) are flattened for both
    maps and compared as two 1-D samples.

    Args:
        sim_map_clean: Tensor indexed as ``[batch, row, col, class]``.
        sim_map_adv: Tensor with the same layout.
        class_idx: Index of the class channel to compare.

    Returns:
        ``(rho, emd)`` as Python floats: ``rho`` from ``scipy.stats.spearmanr``
        and ``emd`` from ``scipy.stats.wasserstein_distance``.
    """
    # Imported here because this file has no explicit import for these names;
    # without it the function only works if a star import happens to re-export them.
    from scipy.stats import spearmanr, wasserstein_distance

    flat_c = sim_map_clean[0, :, :, class_idx].reshape(-1).detach().cpu().numpy()
    flat_a = sim_map_adv[0, :, :, class_idx].reshape(-1).detach().cpu().numpy()
    rho, _ = spearmanr(flat_c, flat_a)
    emd = wasserstein_distance(flat_c, flat_a)
    return float(rho), float(emd)

# ----------------------------
# Helpers
# ----------------------------
def get_text_features(model_clip, texts: List[str]) -> torch.Tensor:
    """Encode ``texts`` with the prompt ensemble and L2-normalise the result.

    Args:
        model_clip: Model handed to ``clip.encode_text_with_prompt_ensemble``.
        texts: Labels to encode.

    Returns:
        The encoded features divided by their norm along the last dimension,
        computed without gradient tracking.
    """
    with torch.no_grad():
        raw_feats = clip.encode_text_with_prompt_ensemble(model_clip, texts, DEVICE)
        unit_feats = raw_feats / raw_feats.norm(dim=-1, keepdim=True)
    return unit_feats

def build_maps_for(model_clip, image_1x3xHxW: torch.Tensor, text_feats: torch.Tensor):
    """Encode an image and build its similarity map against ``text_feats``.

    Both steps run without gradient tracking.

    Args:
        model_clip: Model handed to ``encode_and_norm``.
        image_1x3xHxW: Preprocessed image batch.
        text_feats: Text features handed to ``build_similarity_map``.

    Returns:
        ``(tokens, sim_map)``: the output of ``encode_and_norm`` and the output
        of ``build_similarity_map`` computed from it.
    """
    with torch.no_grad():
        tokens = encode_and_norm(model_clip, image_1x3xHxW)
        return tokens, build_similarity_map(tokens, text_feats)

def cls_cos_and_maxdeltaprob(z_clean_tokens, z_adv_tokens, text_feats):
    """Compare the first-token embeddings of a clean and an adversarial image.

    Token 0 of each sequence is taken as the image-level embedding. Two numbers
    are reported: the cosine similarity between the two embeddings, and the
    largest absolute change in any class probability, where probabilities are
    the softmax over the embedding's dot products with ``text_feats``.

    Args:
        z_clean_tokens: Tensor indexed as ``[batch, token, dim]`` for the clean
            image; a single batch element is expected because the cosine is
            read with ``.item()``.
        z_adv_tokens: Same layout, for the adversarial image.
        text_feats: Matrix of text features, one row per class.

    Returns:
        ``(cosine, max_delta_prob)`` as Python floats.
    """
    cls_clean, cls_adv = z_clean_tokens[:, 0, :], z_adv_tokens[:, 0, :]

    def _class_probs(cls_embedding):
        return torch.softmax(cls_embedding @ text_feats.T, dim=1)

    cosine = F.cosine_similarity(cls_clean, cls_adv).item()
    prob_shift = (_class_probs(cls_adv) - _class_probs(cls_clean)).abs()
    return float(cosine), float(prob_shift.max().item())

def plot_transfer_matrix(matrix_df: pd.DataFrame, title="Transferability (mean IoU_Topk; lower=stronger)",
                         save_path: str = "transfer_matrix_iou.png"):
    """Draw a source x target matrix as an annotated heatmap, save it and show it.

    Args:
        matrix_df: DataFrame whose index labels the rows and whose columns
            label the columns of the heatmap; values are drawn on a fixed
            0..1 colour scale and printed in each cell with two decimals.
        title: Title placed above the heatmap.
        save_path: File the figure is written to. Defaults to the previously
            hard-coded ``"transfer_matrix_iou.png"``; pass a falsy value to
            skip saving (same convention as ``plot_sim_heatmap``).
    """
    fig, ax = plt.subplots(figsize=(6.2, 5))
    im = ax.imshow(matrix_df.values, vmin=0, vmax=1, aspect='auto')

    ax.set_xticks(range(matrix_df.shape[1]))
    ax.set_xticklabels(matrix_df.columns, rotation=30, ha='right')
    ax.set_yticks(range(matrix_df.shape[0]))
    ax.set_yticklabels(matrix_df.index)

    # imshow puts column j on the x axis and row i on the y axis, hence text(j, i, ...).
    for (i, j), v in np.ndenumerate(matrix_df.values):
        ax.text(j, i, f"{v:.2f}", ha='center', va='center', fontsize=10, color='black')

    fig.colorbar(im, ax=ax, label="IoU_Topk")
    ax.set_title(title)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=180)
    plt.show()

def plot_sim_heatmap(sim_map: torch.Tensor, class_idx: int, title: str, save_path: str = None):
    """Show the similarity map of one class as a heatmap, optionally saving it.

    Args:
        sim_map: Tensor indexed as ``[1, H, W, N]``; batch element 0 is drawn.
        class_idx: Index of the class channel to draw.
        title: Title placed above the heatmap.
        save_path: When truthy, the figure is also written to this path.
    """
    heat = sim_map[0, :, :, class_idx]
    heat = heat.detach().cpu().numpy()

    plt.figure(figsize=(4, 4))
    plt.imshow(heat, interpolation='nearest')
    plt.title(title)
    plt.axis('off')

    if save_path:
        plt.savefig(save_path, dpi=180, bbox_inches='tight')
    plt.show()

# ----------------------------
# Main evaluation
# ----------------------------
def main():
    """Craft an adversarial image on every model and measure how it transfers.

    Steps:
      1. Load the sample image at ``IMAGE_PATH`` and, when ``SURGERY_MODEL`` is
         set, the surgery model.
      2. Load every model in ``MODELS`` and encode ``ALL_TEXTS`` with each one.
      3. Call ``generate_adversarial_image`` once per source model, passing the
         index of ``CONFIG["target_text_adv"]`` as ``target_idx`` and the
         entries of ``ATTACK_CFG`` as keyword arguments.
      4. For every (source, target) pair, compare the target model's clean and
         adversarial outputs for ``CONFIG["target_text_clean"][0]``: CLS cosine
         similarity, maximum probability change, top-k IoU, soft IoU, Spearman
         correlation and Wasserstein distance; then call ``plot_all_together``.
      5. Write ``transfer_outputs/transfer_results.csv`` and plot the
         source x target matrix of ``IoU_Topk``.
    """
    os.makedirs("transfer_outputs", exist_ok=True)

    # Load input
    image, cv2_img = load_image(IMAGE_PATH)

    # Optional surgery model
    surgery_model = None
    if SURGERY_MODEL:
        surgery_model, _ = clip.load(SURGERY_MODEL, device=DEVICE)
        surgery_model.eval()

    # Load all models + per-model text features (text encoders differ by variant)
    models: Dict[str, torch.nn.Module] = {}
    text_feats: Dict[str, torch.Tensor] = {}
    for name in MODELS:
        model, _ = clip.load(name, device=DEVICE)
        model.eval()
        models[name] = model
        text_feats[name] = get_text_features(model, ALL_TEXTS)

    target_text = CONFIG["target_text_clean"][0]            # label whose map is evaluated
    target_idx = ALL_TEXTS.index(target_text)
    target_idx_adv = ALL_TEXTS.index(CONFIG["target_text_adv"])

    # The attack arguments do not depend on the source model, so build them once.
    attack_args = ATTACK_CFG.copy()
    if ATTACK_CFG.get("use_surgery", False) and surgery_model is not None:
        attack_args["model_surgery"] = surgery_model

    # Craft one adversarial image per source model (the returned history is not used).
    adv_images: Dict[str, torch.Tensor] = {}
    for src_name in MODELS:
        print(f"\n[Attack] Crafting on source: {src_name}")
        adv_images[src_name], _ = generate_adversarial_image(
            image=image,
            model=models[src_name],
            text_features=text_feats[src_name],
            target_idx=target_idx_adv,
            preserve_idx=None,
            **attack_args
        )

        # Optional quick preview (disabled). To re-enable:
        # clean_np = denorm_clip(image).squeeze(0).permute(1,2,0).cpu().numpy()
        # adv_np   = denorm_clip(adv_images[src_name]).squeeze(0).permute(1,2,0).cpu().numpy()
        # plt.figure(figsize=(6,3))
        # plt.subplot(1,2,1); plt.imshow(clean_np); plt.axis('off'); plt.title(f"Clean ({src_name})")
        # plt.subplot(1,2,2); plt.imshow(adv_np);   plt.axis('off'); plt.title(f"Adv on {src_name}")
        # plt.tight_layout()
        # plt.savefig(f"transfer_outputs/preview_{src_name.replace('/', '-')}.png", dpi=160); plt.show()

    # The clean embeddings/maps depend only on the target model, so compute them
    # once per model instead of once per (source, target) pair.
    clean_maps = {
        name: build_maps_for(models[name], image, text_feats[name])
        for name in MODELS
    }

    # Evaluate transfer: for each (src δ) on each target model
    rows = []
    for src_name in MODELS:
        adv_img = adv_images[src_name]
        for tgt_name in MODELS:
            print(f"[Eval] Source δ: {src_name}  →  Target model: {tgt_name}")
            tgt_feats = text_feats[tgt_name]

            # Build maps on target model
            z_clean_tgt, sim_clean_tgt = clean_maps[tgt_name]
            z_adv_tgt, sim_adv_tgt = build_maps_for(models[tgt_name], adv_img, tgt_feats)

            # Metrics
            cos_cls, max_dp = cls_cos_and_maxdeltaprob(z_clean_tgt, z_adv_tgt, tgt_feats)
            iou_k = iou_topk_percent(sim_clean_tgt, sim_adv_tgt, target_idx, frac=PERCENT_K)
            iou_soft = soft_iou(sim_clean_tgt, sim_adv_tgt, target_idx, temp=SOFT_TEMP)
            rho, emd = spearman_and_emd(sim_clean_tgt, sim_adv_tgt, target_idx)

            rows.append(dict(
                source=src_name, target=tgt_name,
                CosSim_CLS=cos_cls, MaxDeltaProb=max_dp,
                IoU_Topk=iou_k, IoU_Soft=iou_soft, Spearman=rho, EMD=emd
            ))

            plot_all_together(
                model_name=tgt_name,
                orig_img_t=image,
                adv_img_t=adv_img,
                cv2_img=cv2_img,
                sim_map_clean=sim_clean_tgt,
                sim_map_adv=sim_adv_tgt,
                sim_map_clean_fs=sim_clean_tgt,  # placeholder if FaithShield maps not computed
                sim_map_adv_fs=sim_adv_tgt,
                all_texts=ALL_TEXTS,
                target_texts=[target_text],
                src_model=src_name,
                tgt_model=tgt_name
            )

            # Optional per-pair heatmaps (disabled). Model names contain "/",
            # so sanitise them before using them in file names:
            # src_safe = src_name.replace("/", "-")
            # tgt_safe = tgt_name.replace("/", "-")
            # plot_sim_heatmap(sim_clean_tgt, target_idx, f"Clean {tgt_name} ({target_text})",
            #                  save_path=f"transfer_outputs/heatmap_clean_{src_safe}_to_{tgt_safe}.png")
            # plot_sim_heatmap(sim_adv_tgt, target_idx, f"Adv δ({src_safe}) on {tgt_safe} ({target_text})",
            #                  save_path=f"transfer_outputs/heatmap_adv_{src_safe}_to_{tgt_safe}.png")

    results_df = pd.DataFrame(rows)
    results_df.to_csv("transfer_outputs/transfer_results.csv", index=False)
    print("\nSaved: transfer_outputs/transfer_results.csv")
    print(results_df)

    # Transfer matrix: mean IoU_Topk per (src→tgt)
    matrix = results_df.pivot_table(index="source", columns="target", values="IoU_Topk", aggfunc="mean")
    plot_transfer_matrix(matrix, title="Transferability Heatmap")

    # Optional: scatter stealth vs manipulation (disabled)
    # plt.figure(figsize=(5,4))
    # x = results_df["CosSim_CLS"].values
    # y = 1.0 - results_df["IoU_Topk"].values
    # colors = [MODELS.index(t) for t in results_df["target"]]
    # plt.scatter(x, y, c=colors, alpha=0.7)
    # plt.axvline(STEALTH_THRESHOLD, ls='--', label=f"stealth τ={STEALTH_THRESHOLD}")
    # plt.xlabel("CosSim_CLS (stealth ↑)")
    # plt.ylabel("1 - IoU_Topk (manipulation ↑)")
    # plt.title("Stealth vs Manipulation")
    # plt.legend()
    # plt.tight_layout()
    # plt.savefig("transfer_outputs/stealth_vs_manip.png", dpi=160)
    # plt.show()
    

# Run the transfer evaluation only when this file is executed as a script.
# NOTE: the module-level IMAGE_PATH check above still runs on a plain import.
if __name__ == "__main__":
    main()


