import torch
import numpy as np
import clip
import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from adaptive_attack import generate_adaptive_attack
from Utils import *
from Adv_attack import *
from Detection import *
from Evaluation import *
from Plots import *

# Constants
# Resampling filter used by `custom_preprocess` below.
BICUBIC = InterpolationMode.BICUBIC

# --- Load models ---
# Backbone for the vanilla CLIP model. Alternatives:
# model_name = "ViT-L/14"
# model_name = "ViT-B/32"
model_name = "ViT-B/16"
# model_name = "RN50"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the vanilla CLIP model and switch it to eval mode.
# NOTE(review): `preprocess` (the transform returned by clip.load) is not used
# in the visible code; `custom_preprocess` below is applied instead.
model, preprocess = clip.load(model_name, device=device)
model.eval()

# Load the CLIP-Surgery variant (choose one). Keep its backbone in sync with
# `model_name`: text features encoded with `model` are multiplied against this
# model's image tokens during evaluation, so the embedding widths must match.
# surgery_model_name = "CS-ViT-L/14"
# surgery_model_name = "CS-ViT-B/32"
surgery_model_name = "CS-ViT-B/16"
# surgery_model_name = "CS-RN50"
model_surgery, _ = clip.load(surgery_model_name, device=device)
model_surgery.eval()

# Custom preprocessing used instead of `preprocess`: a direct resize to
# 224x224 (aspect ratio not preserved, no center crop), conversion to a
# tensor, and per-channel normalization with the mean/std values below.
custom_preprocess = Compose([
    Resize((224, 224), interpolation=BICUBIC),
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073),
              (0.26862954, 0.26130258, 0.27577711))
])

# --- Load image ---
# `pil_img` (RGB) feeds the models; `cv2_img` is a BGR copy that is passed to
# the heatmap plotting at the end of the script.
pil_img = Image.open("demo.jpg").convert("RGB")
cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

image = custom_preprocess(pil_img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

# --- Texts ---
# Label vocabulary used for text features and similarity maps. Besides object
# classes it contains filler words ('this', 'is', 'a') and full phrases.
# NOTE(review): 'staire' looks like a typo for 'stairs'; it is left unchanged
# here because editing it changes the text embedding for that entry.
all_texts = [
    'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 
    'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', 
    'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 
    'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', 'potted plant', 
    'road', 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 
    'train', 'tree', 'truck', 'tv monitor', 'wall', 'water', 'window', 'wood', 'staire', 'desk', 
    'cards', 'vespa', 'bear', 'banana', 'piano', 'this', 'is', 'a', 'this is a bench', 'this is a wall']

# Target label(s) passed to the heatmap plotting at the end of the script.
target_texts = ['bench']
# NOTE(review): `target_idx` is not referenced in the visible code.
target_idx = [all_texts.index(t) for t in target_texts]
# Class index used for the Stage-II (confidence-drop) evaluation.
target_idx_stage2 = all_texts.index('bench')

# Class that both attacks push the similarity map toward.
target_text = "wall"
target_idx_adv = all_texts.index(target_text)

# 224x224 PIL copy of the input image.
# NOTE(review): `img` is not referenced in the visible code.
img = pil_img.resize((224, 224))

# --- Step 1: Encode text ---
# Encode every label with the vanilla model using the `clip` package's
# prompt-ensemble helper, then L2-normalize each row.
with torch.no_grad():
    text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    print("Text feature dim:", text_features.shape)

# --- Step 2: Non-adaptive adversarial attack ---
# NOTE(review): `generate_adversarial_image` is not imported by name; it
# presumably comes from one of the `from ... import *` modules (e.g.
# Adv_attack). Its second return value (`history`) is unused in the visible code.
image_adv, history = generate_adversarial_image(
    image=image,
    model=model,
    text_features=text_features,
    target_idx=target_idx_adv,   # push the similarity map toward "wall"
    preserve_idx=None,  # None: only the clean CLIP output is preserved;
    # passing a class index here would additionally push the similarity
    # map toward that class.
    steps=400,
    step_size=0.05,
    lambda_pred=0.5,
    lambda_entropy=0.5,
    lambda_margin=0.5,
    K_patches=100,
    l0_ratio=0.05,
    use_margin=True,
    use_surgery=False,
    model_surgery=model_surgery,
    surgery_temp=2.0,
    log_every=10)



print("\n=== Running Adaptive Attack (FaithShield-aware) ===")

# Encode baseline text features.
# NOTE(review): this repeats Step 1 exactly (same model, same vocabulary), so
# `text_features_clip` should match `text_features`, assuming the encoding is
# deterministic. It is also what gets passed to the plotting at the end of the script.
with torch.no_grad():
    text_features_clip = clip.encode_text_with_prompt_ensemble(model, all_texts, device)
    text_features_clip /= text_features_clip.norm(dim=-1, keepdim=True)

# Run the adaptive attack (from adaptive_attack), which receives both the
# vanilla and the surgery model.
image_adv_adaptive, history_adaptive = generate_adaptive_attack(
    model=model,
    model_surgery=model_surgery,      # S1+S2 + FS-aware
    image=image,
    text_features=text_features_clip,
    target_idx=target_idx_adv,
    steps=300,
    K_patches=50,
    l0_ratio=0.01,
    verbose=True
)

# ==========================================================
#   FAITHSHIELD EVALUATION: Stage-I explanation robustness
#   and Stage-II confidence drop
# ==========================================================

import torch
import torch.nn.functional as F
import numpy as np
import cv2
import pandas as pd


# ----------------------------------------------------------
# Utility: Normalize Image Tokens
# ----------------------------------------------------------
def encode_and_norm(model, image):
    """Encode ``image`` with ``model`` and L2-normalize along the last axis.

    A 1e-6 term is added to the norm so that an all-zero embedding yields
    zeros instead of NaNs.
    """
    features = model.encode_image(image)
    norms = features.norm(dim=-1, keepdim=True)
    return features / (norms + 1e-6)


# ----------------------------------------------------------
# Build patch-level similarity map (Stage I core)
# ----------------------------------------------------------
def build_similarity_map(z, text_features):
    """Score every patch token against every text embedding, on the patch grid.

    Args:
        z: Token embeddings of shape [B, 1 + H*W, D]: a leading CLS token
            followed by a square grid of patch tokens.
        text_features: Text embeddings of shape [N, D].

    Returns:
        Tensor of shape [B, H, W, N] holding the dot product of each patch
        token with each text embedding (cosine similarity when both inputs
        are L2-normalized). The CLS token is dropped.

    Raises:
        RuntimeError: if ``z`` is not 3-D, or its patch tokens do not form a
            non-empty square grid.
    """
    if z.dim() != 3:
        raise RuntimeError("Model must output CLS+patch tokens (ViT).")

    batch, num_tokens = z.shape[0], z.shape[1]
    num_patches = num_tokens - 1
    side = int(round(num_patches ** 0.5)) if num_patches > 0 else 0
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if side == 0 or side * side != num_patches:
        raise RuntimeError(
            f"Expected 1 CLS token followed by a square patch grid, "
            f"got {num_tokens} tokens."
        )

    scores = z @ text_features.T                          # [B, T, N]
    # Keep the batch axis: reshaping to a hard-coded batch of 1 would fold
    # extra images into the class axis.
    return scores[:, 1:, :].reshape(batch, side, side, -1)


# ----------------------------------------------------------
# Make Stage-I Top-K mask (using SURGERY model only)
# ----------------------------------------------------------
def make_topk_mask(sim_map, target_idx, k_ratio=0.10):
    """Boolean mask of the highest-scoring patches for one class.

    Args:
        sim_map: Similarity map of shape [B, H, W, N]; only batch element 0
            is used.
        target_idx: Index of the class channel used to rank patches.
        k_ratio: Fraction of the H*W patches to select. The resulting count
            is clamped so that at least one and at most all patches are kept.

    Returns:
        Bool tensor of shape [H, W], True on the k top-scoring patches.
    """
    height, width = sim_map.shape[1], sim_map.shape[2]
    flat = sim_map[0, :, :, target_idx].reshape(-1)

    num = flat.numel()
    # Clamp k: `topk` raises when k exceeds the number of elements
    # (e.g. k_ratio > 1).
    k = min(num, max(1, int(num * k_ratio)))

    mask = torch.zeros_like(flat, dtype=torch.bool)
    mask[flat.topk(k).indices] = True
    return mask.reshape(height, width)


# ----------------------------------------------------------
# Upsample 14×14 mask → 224×224 pixels
# ----------------------------------------------------------
def upsample_mask(mask_hw, H_full=224, W_full=224):
    """Nearest-neighbour upsample a patch-grid mask to pixel resolution.

    Args:
        mask_hw: 2-D torch tensor, e.g. the [H, W] bool mask produced by
            make_topk_mask.
        H_full, W_full: Target height and width in pixels.

    Returns:
        numpy uint8 array of shape (H_full, W_full): 1 where the upsampled
        mask exceeds 0.5, 0 elsewhere.
    """
    grid = np.asarray(mask_hw.cpu().numpy(), dtype=np.float32)
    # cv2.resize takes the target size as (width, height).
    resized = cv2.resize(grid, (W_full, H_full), interpolation=cv2.INTER_NEAREST)
    return np.where(resized > 0.5, 1, 0).astype(np.uint8)


# ----------------------------------------------------------
# Apply pixel mask to image
# ----------------------------------------------------------
def apply_mask(image, mask_full):
    """Zero out the pixels selected by a spatial mask.

    Args:
        image: Image tensor of shape [B, C, H, W].
        mask_full: [H, W] mask (numpy array or tensor; bool, integer or
            float) where 1/True marks pixels to remove. Fractional values
            scale the pixel by ``1 - mask``.

    Returns:
        ``image * (1 - mask)`` in ``image``'s dtype, with the mask broadcast
        over the batch and channel axes.

    Note:
        Pixels are zeroed in whatever space ``image`` lives in. For an image
        produced by a mean/std ``Normalize`` transform, 0 corresponds to the
        per-channel mean colour, not black.
    """
    mask = torch.as_tensor(mask_full, device=image.device)
    # Convert before subtracting: `1 - mask` is unsupported for bool tensors
    # and would wrap around for unsigned integer masks.
    keep = 1 - mask.to(image.dtype)
    # [H, W] -> [1, 1, H, W] broadcasts over any batch size / channel count.
    return image * keep[None, None, :, :]


# ----------------------------------------------------------
# Stage-II — Compute BASE CLIP confidence (NOT surgery!)
# ----------------------------------------------------------
def compute_confidence_clip(model, image, text_features, logit_scale=1.0):
    """Top-1 softmax probability of the vanilla model over the text vocabulary.

    Args:
        model: Object exposing ``encode_image(image)`` returning [B, D] or
            [B, T, D]; for 3-D output token 0 is used as the CLS embedding.
        image: Input batch for ``model.encode_image``.
        text_features: [N, D] text embeddings, used as-is (presumably already
            L2-normalized by the caller).
        logit_scale: Multiplier applied to the cosine similarities before the
            softmax. The default 1.0 keeps the original behaviour, but raw
            cosines lie in [-1, 1], so the softmax is then close to uniform
            (about 1/N). For CLIP-style zero-shot probabilities pass the
            model's learned scale, e.g. ``model.logit_scale.exp().item()``
            if the model exposes it (OpenAI CLIP's is about 100).

    Returns:
        float: the largest probability in the softmax output, taken over all
        batch rows and classes.
    """
    with torch.no_grad():
        z = model.encode_image(image)          # [B, D] or [B, T, D]
        if z.dim() == 3:
            z = z[:, 0, :]                     # CLS token

        z = z / (z.norm(dim=-1, keepdim=True) + 1e-6)

        logits = logit_scale * (z @ text_features.T)
        probs = torch.softmax(logits, dim=-1)

        # confidence = max probability
        return float(probs.max())


def compute_iou(mask1, mask2):
    """Intersection-over-union of two boolean masks given as torch tensors.

    Both masks are moved to the CPU and compared element-wise. A 1e-6 term in
    the denominator avoids division by zero, so two empty masks give 0.0.
    """
    a = mask1.cpu().numpy()
    b = mask2.cpu().numpy()
    overlap = np.count_nonzero(np.logical_and(a, b))
    combined = np.count_nonzero(np.logical_or(a, b))
    return float(overlap / (combined + 1e-6))

# ==========================================================
#  FaithShield Stage-II Confidence (Vanilla CLIP)
# ==========================================================
def compute_cls_confidence(model, image, text_features, target_idx):
    """
    Cosine similarity between the image CLS embedding and one text embedding.

    Handles encoders that return a CLS-only [B, D] embedding as well as ViT
    token outputs [B, T, D] (token 0 is taken as CLS). Image and text vectors
    are both L2-normalized with a 1e-6 epsilon before the dot product.

    Returns:
        (cos, conf): the cosine as a float, and its affine map (1 + cos) / 2
        from [-1, 1] to [0, 1].

    Raises:
        ValueError: if encode_image returns neither a 2-D nor a 3-D tensor.
    """
    with torch.no_grad():
        feats = model.encode_image(image)
        feats = feats / (feats.norm(dim=-1, keepdim=True) + 1e-6)

        ndim = feats.dim()
        if ndim == 3:
            cls_emb = feats[:, 0, :]       # ViT tokens: first token is CLS
        elif ndim == 2:
            cls_emb = feats                # already a CLS-only embedding
        else:
            raise ValueError(f"Unexpected encode_image shape: {feats.shape}")

        text_vec = text_features[target_idx].unsqueeze(0)
        text_vec = text_vec / (text_vec.norm() + 1e-6)

        cos = float((cls_emb @ text_vec.T).item())
        return cos, (1 + cos) / 2

# ==========================================================
#   FaithShield Evaluator 
# ==========================================================
def _cls_embedding(z):
    """Return the CLS embedding from an encode_image output ([B, D] or [B, T, D])."""
    if z.dim() == 2:
        return z
    if z.dim() == 3:
        return z[:, 0, :]
    raise ValueError(f"Unexpected encode_image shape: {z.shape}")


def evaluate_faithshield(
    model,              # vanilla CLIP
    model_surgery,      # CLIP-Surgery (Stage-1 robustness)
    image_clean,
    image_adv,
    text_features,
    target_idx,
    k_ratio=0.10
):
    """Compute FaithShield robustness metrics for one adversarial image.

    Stage I (explanation robustness) uses the surgery model. It builds
    patch-level similarity maps for the clean and adversarial images, takes
    the top-``k_ratio`` patch mask for ``target_idx`` from each, and reports
    their IoU.

    Stage II (confidence drop) uses the vanilla model. It zeroes the
    adversarial image's Stage-I patches, upsampled to pixel resolution, and
    measures how much the CLS-text confidence for ``target_idx`` falls.

    Classification-stability metrics compare the vanilla model's clean and
    adversarial CLS embeddings.

    Args:
        model: Vanilla model; its encode_image may return [B, D] or [B, T, D].
        model_surgery: Surgery model; its encode_image must return CLS +
            square patch-grid tokens.
        image_clean, image_adv: [1, C, H, W] image tensors.
        text_features: [N, D] text embeddings, shared by both models
            (presumably L2-normalized).
        target_idx: Class index used for the masks and the confidence drop.
        k_ratio: Fraction of patches kept in each top-k mask.

    Returns:
        dict with keys:
            CosSim_CLS_Drop: 1 - cos(CLS_clean, CLS_adv); 0 means unchanged.
            CosSim_CLS: cos(CLS_clean, CLS_adv) itself.
            MaxDeltaProb: max |p_adv - p_clean|, with softmax taken over raw
                (unscaled) cosine similarities.
            IoU_TopK: IoU of the clean vs. adversarial Stage-I masks.
            FS_ConfDrop: SimOrig - SimMasked.
            SimOrig / SimMasked: confidence (1 + cos) / 2 for ``target_idx``
                on the adversarial image before / after masking.
    """
    # --------------------------------------------------
    # Stage-I (Explanation Robustness) — surgery model
    # --------------------------------------------------
    with torch.no_grad():
        z_clean_s = encode_and_norm(model_surgery, image_clean)
        z_adv_s = encode_and_norm(model_surgery, image_adv)

        sim_clean_s = build_similarity_map(z_clean_s, text_features)
        sim_adv_s = build_similarity_map(z_adv_s, text_features)

    mask_clean = make_topk_mask(sim_clean_s, target_idx, k_ratio)
    mask_adv = make_topk_mask(sim_adv_s, target_idx, k_ratio)
    iou = compute_iou(mask_clean, mask_adv)

    # --------------------------------------------------
    # Stage-II (Confidence Drop) — vanilla model
    # --------------------------------------------------
    h_full, w_full = image_clean.shape[2], image_clean.shape[3]
    mask_adv_full = upsample_mask(mask_adv, h_full, w_full)

    _, conf_orig = compute_cls_confidence(model, image_adv, text_features, target_idx)
    masked_adv = apply_mask(image_adv, mask_adv_full)   # remove top-k pixels
    _, conf_mask = compute_cls_confidence(model, masked_adv, text_features, target_idx)

    fs_conf_drop = float(conf_orig - conf_mask)

    # --------------------------------------------------
    # Classification stability — vanilla model
    # --------------------------------------------------
    with torch.no_grad():
        # Works for both CLS-only and token outputs, like compute_cls_confidence.
        z_clean = _cls_embedding(encode_and_norm(model, image_clean))
        z_adv = _cls_embedding(encode_and_norm(model, image_adv))

        cls_cos = float(F.cosine_similarity(z_clean, z_adv).item())
        # A "drop" is measured against cos(clean, clean) = 1.
        cls_cos_drop = 1.0 - cls_cos

        # Softmax label shift
        pc = torch.softmax(z_clean @ text_features.T, dim=-1)
        pa = torch.softmax(z_adv @ text_features.T, dim=-1)
        max_delta_prob = float((pa - pc).abs().max())

    return {
        "CosSim_CLS_Drop": cls_cos_drop,
        "CosSim_CLS": cls_cos,
        "MaxDeltaProb": max_delta_prob,
        "IoU_TopK": iou,
        "FS_ConfDrop": fs_conf_drop,
        "SimOrig": conf_orig,
        "SimMasked": conf_mask
    }



# ----------------------------------------------------------
# Evaluate BOTH attacks (Non-adaptive + Adaptive)
# ----------------------------------------------------------
def evaluate_both_attacks(
    model, model_surgery,
    image_clean,
    image_adv_nonadaptive,
    image_adv_adaptive,
    text_features,
    target_idx,
    k_ratio=0.10,
):
    """Run evaluate_faithshield on both adversarial images and tabulate the results.

    Args:
        model, model_surgery: Vanilla and surgery models.
        image_clean: Clean reference image.
        image_adv_nonadaptive, image_adv_adaptive: The two attack outputs.
        text_features: [N, D] text embeddings.
        target_idx: Class index used for the metrics.
        k_ratio: Top-k patch fraction forwarded to evaluate_faithshield
            (its own default is 0.10).

    Returns:
        pandas.DataFrame with one row per attack: the metric columns returned
        by evaluate_faithshield plus an "Attacker" column naming the attack.
    """
    attacks = (
        ("Non-Adaptive", image_adv_nonadaptive),
        ("Adaptive", image_adv_adaptive),
    )

    rows = []
    for name, adv_img in attacks:
        metrics = evaluate_faithshield(
            model, model_surgery,
            image_clean, adv_img,
            text_features, target_idx,
            k_ratio=k_ratio,
        )
        metrics["Attacker"] = name
        rows.append(metrics)

    return pd.DataFrame(rows)


# Evaluate both adversarial images against the clean image. Stage-II metrics
# are computed for the 'bench' class (target_idx_stage2).
df = evaluate_both_attacks(
    model, model_surgery,
    image,                 # clean reference image
    image_adv,             # non-adaptive attack output (Step 2)
    image_adv_adaptive,    # adaptive attack output
    text_features,
    target_idx_stage2
)


print(df)

def get_token_maps(model, image, text_features):
    """Return normalized tokens and a patch-grid similarity map for an image batch.

    Args:
        model: Object exposing ``encode_image(image)``; it must return
            [B, 1 + H*W, D] (a CLS token followed by a square patch grid).
        image: Input batch for ``model.encode_image``.
        text_features: [N, D] text embeddings.

    Returns:
        (z, sim_map): ``z`` is the L2-normalized token tensor [B, T, D];
        ``sim_map`` is [B, H, W, N], the dot product of each patch token with
        each text embedding (CLS token dropped).

    Raises:
        RuntimeError: if encode_image returns no patch tokens, or the patch
            tokens do not form a square grid.
    """
    with torch.no_grad():
        z = model.encode_image(image)           # [B, T, D] for ViT
        z = z / z.norm(dim=-1, keepdim=True)

        # detect tokens
        if z.dim() != 3 or z.shape[1] <= 1:
            raise RuntimeError("encode_image returned CLS-only (no tokens).")

        batch, num_tokens = z.shape[0], z.shape[1]
        num_patches = num_tokens - 1
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise RuntimeError(
                f"Patch tokens do not form a square grid ({num_patches} patches)."
            )

        s = z @ text_features.T                 # [B, T, N]
        # Keep the real batch size so B > 1 is not folded into the class axis.
        sim_map = s[:, 1:, :].reshape(batch, side, side, -1)

    return z, sim_map



def plot_attack_heatmaps(
    model,
    model_surgery,
    image_clean,
    adv_nonadaptive,
    adv_adaptive,
    text_features,
    all_texts,
    target_texts,
    cv2_img
):
    """Plot vanilla and surgery similarity heatmaps for both attacks.

    Similarity maps are computed with get_token_maps for the clean image and
    for each adversarial image, under both ``model`` and ``model_surgery``.
    Each attack is then rendered with plot_all_together.

    Raises:
        ValueError: if ``target_texts[0]`` is not in ``all_texts``. This is
            checked before any forward pass.
    """
    # Fail fast, before the forward passes, if the target label is missing.
    all_texts.index(target_texts[0])

    def _maps(img):
        """(vanilla map, surgery map) for one image; token tensors are discarded."""
        _, sim_base = get_token_maps(model, img, text_features)
        _, sim_surgery = get_token_maps(model_surgery, img, text_features)
        return sim_base, sim_surgery

    print("Computing CLEAN maps...")
    sim_map_clean, sim_map_clean_s = _maps(image_clean)

    print("Computing NON-ADAPTIVE maps...")
    sim_map_adv_na, sim_map_adv_na_s = _maps(adv_nonadaptive)

    print("Computing ADAPTIVE maps...")
    sim_map_adv_ad, sim_map_adv_ad_s = _maps(adv_adaptive)

    # (plot title, banner, adversarial image, vanilla map, surgery map)
    runs = (
        ("Non-Adaptive Attack", "\n=== Plotting NON-ADAPTIVE ATTACK ===",
         adv_nonadaptive, sim_map_adv_na, sim_map_adv_na_s),
        ("Adaptive Attack", "\n=== Plotting ADAPTIVE ATTACK (FaithShield-aware) ===",
         adv_adaptive, sim_map_adv_ad, sim_map_adv_ad_s),
    )
    for title, banner, adv_img, sim_adv, sim_adv_s in runs:
        print(banner)
        plot_all_together(
            model_name=title,
            orig_img_t=image_clean,
            adv_img_t=adv_img,
            cv2_img=cv2_img,
            sim_map_clean=sim_map_clean,
            sim_map_adv=sim_adv,
            sim_map_clean_fs=sim_map_clean_s,
            sim_map_adv_fs=sim_adv_s,
            all_texts=all_texts,
            target_texts=target_texts
        )

    print("\nDone! Heatmaps generated successfully.\n")

# Render vanilla vs. surgery heatmaps for the clean image and both attack
# outputs. Uses the text features encoded just before the adaptive attack.
plot_attack_heatmaps(
    model=model,
    model_surgery=model_surgery,
    image_clean=image,
    adv_nonadaptive=image_adv,
    adv_adaptive=image_adv_adaptive,
    text_features=text_features_clip,
    all_texts=all_texts,
    target_texts=target_texts,
    cv2_img=cv2_img
)
