import os
import sys
import cv2
import torch
import clip
import numpy as np
import pandas as pd
import torch.nn.functional as F
import matplotlib.pyplot as plt

from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

from Utils import *
from Adv_attack import *
from Evaluation import *
from Plots import *

# Settings for this single-image experiment.
CONFIG = {
    # Key into DATASET_ROOTS / DATASET_DEFAULT_SAMPLE below. Available:
    # Flickr30K (dog), COCO (cat), ImageNet (bird), Demo (bench) - the word in
    # parentheses is presumably the object shown in that dataset's sample image.
    "dataset": "Demo",
    "base_model": "ViT-B/16",        # identifier of the baseline CLIP model
    "surgery_model": "CS-ViT-B/16",  # identifier of the CLIP Surgery model
    "target_text_clean": ["bench"],  # label(s) looked up in all_texts for the clean image
    "target_text_adv": "ground",     # label looked up in all_texts as the attack target
    "use_custom_preprocess": False,  # True -> custom_preprocess, False -> transform from clip.load
}

# Directory that holds the images of each dataset.
DATASET_ROOTS = {
    "Flickr30K": "F:/All codes/CLIP_Surgery/flick30Images/flickr30k_images/",
    "COCO":      "F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/",
    "ImageNet":  "F:/All codes/CLIP_Surgery/image-net-val/",
    "Demo":      "F:/All codes/CLIP_Surgery/codes_paper/",
}

# File name (relative to the dataset root) of the image used for each dataset.
DATASET_DEFAULT_SAMPLE = {
    "Flickr30K": "459814265.jpg",
    "COCO":      "000000003599.jpg",
    "ImageNet":  "ILSVRC2012_val_00002023.JPEG",
    "Demo":      "demo.jpg",
}

# Build the image path. Both tables are checked so that a dataset registered in
# only one of them fails with a readable ValueError instead of a bare KeyError.
_dataset = CONFIG["dataset"]
if _dataset not in DATASET_ROOTS:
    raise ValueError(f"Dataset '{CONFIG['dataset']}' not configured!")
if _dataset not in DATASET_DEFAULT_SAMPLE:
    raise ValueError(f"Dataset '{_dataset}' has no default sample image configured!")

IMAGE_PATH = os.path.join(DATASET_ROOTS[_dataset], DATASET_DEFAULT_SAMPLE[_dataset])


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def load_clip(model_name):
    """Load a CLIP checkpoint onto the global ``device`` and switch it to eval mode.

    Args:
        model_name: Model identifier, passed unchanged to ``clip.load``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load``; ``eval()``
        has already been called on the model.
    """
    print(f"Loading CLIP model: {model_name}")
    loaded = clip.load(model_name, device=device)
    clip_model, transform = loaded
    clip_model.eval()
    return clip_model, transform
# Resolve both model identifiers, then load the models in the same order as the
# log output expects: baseline first, surgery model second.
model_name = CONFIG["base_model"]
surgery_model_name = CONFIG["surgery_model"]

model, base_preprocess = load_clip(model_name)
# NOTE(review): nothing visible in this file reads `model_surgery` (the attack is
# called with model_surgery=None and the ablation loop loads its own models) -
# confirm whether this extra load is still needed.
model_surgery, _ = load_clip(surgery_model_name)
# Hand-built preprocessing pipeline; preprocess_image uses it instead of the
# transform returned by clip.load when CONFIG["use_custom_preprocess"] is set.
BICUBIC = InterpolationMode.BICUBIC

# Per-channel normalisation statistics (presumably the ones CLIP was trained
# with - TODO confirm against the clip package).
_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_NORM_STD = (0.26862954, 0.26130258, 0.27577711)

custom_preprocess = Compose(
    [
        Resize(size=(224, 224), interpolation=BICUBIC),
        ToTensor(),
        Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Candidate text prompts. The order matters: the script resolves its target
# labels with all_texts.index(...), so positions double as class indices.
all_texts = [
    "airplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird", "boat",
    "book", "bottle", "building", "bus", "cabinet", "car", "cat", "ceiling",
    "chair", "cloth", "computer", "cow", "cup", "curtain", "dog", "door",
    "fence", "floor", "flower", "food", "grass", "ground", "horse", "keyboard",
    "light", "motorbike", "mountain", "mouse", "person", "plate", "platform",
    "potted plant", "road", "rock", "sheep", "shelves", "sidewalk", "sign",
    "sky", "snow", "sofa", "table", "track", "train", "tree", "truck",
    "tv monitor", "wall", "water", "window", "wood",
    # NOTE(review): "staire" looks like a typo (for "stairs"?). Left as is
    # because changing it changes the encoded text features - confirm first.
    "staire",
    "desk", "cards", "vespa", "bear", "banana", "piano",
    # Single words and whole sentences rather than object names.
    "this", "is", "a", "this is a bench", "this is a wall",
]
def preprocess_image(pil_img):
    """Turn a PIL image into a single-image batch tensor on ``device``.

    ``custom_preprocess`` is applied when ``CONFIG["use_custom_preprocess"]``
    is truthy; otherwise ``base_preprocess`` (the transform returned by
    ``clip.load`` for the base model) is applied.

    Args:
        pil_img: Image accepted by the selected transform.

    Returns:
        The transformed image with a leading dimension added by
        ``unsqueeze(0)``, moved to ``device``.
    """
    transform = custom_preprocess if CONFIG["use_custom_preprocess"] else base_preprocess
    batch = transform(pil_img).unsqueeze(0)
    return batch.to(device)

# Load the sample image once. The context manager closes the underlying file
# deterministically; convert() returns an independent in-memory image, so
# pil_img stays usable after the file is closed.
with Image.open(IMAGE_PATH) as _image_file:
    pil_img = _image_file.convert("RGB")

# Model input: preprocessed single-image batch on `device`.
image = preprocess_image(pil_img)
# Same picture as a BGR array for the OpenCV-based plotting / resizing code.
cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

# Translate the configured target labels into positions in all_texts.
# list.index raises ValueError if a configured label is not in all_texts.
target_texts = CONFIG["target_text_clean"]
target_idx = list(map(all_texts.index, target_texts))
target_idx_single = target_idx[0]  # first clean target

target_idx_adv = all_texts.index(CONFIG["target_text_adv"])
# Despite its name this holds the same index as target_idx_adv, not a string.
target_text = target_idx_adv


# Encode every entry of all_texts with the baseline model and scale each
# feature vector to unit length along the last dimension.
with torch.no_grad():
    _raw_text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device)
    _text_norms = _raw_text_features.norm(dim=-1, keepdim=True)
    text_features_clip = _raw_text_features / _text_norms

# =========================
# 4. GENERATE ADVERSARIAL IMAGE
# =========================
print("Generating adversarial image...")

# All attack settings in one place; every entry is forwarded unchanged as a
# keyword argument to generate_adversarial_image. The attack runs against the
# baseline model only (use_surgery=False, model_surgery=None).
_attack_kwargs = dict(
    image=image,
    model=model,
    text_features=text_features_clip,
    target_idx=target_idx_adv,
    preserve_idx=None,
    steps=400,
    step_size=0.05,
    lambda_pred=0.5,
    lambda_entropy=0.5,
    lambda_margin=0.5,
    K_patches=100,
    l0_ratio=0.05,
    use_margin=True,
    use_surgery=False,
    model_surgery=None,
    surgery_temp=2.0,
    log_every=10,
)
image_adv, history = generate_adversarial_image(**_attack_kwargs)

# =========================
# 5. BASELINE CLIP TOKENS & SIM MAPS (CLEAN + ADV)
# =========================
print("Computing baseline CLIP tokens & similarity maps (clean + adv)...")

with torch.no_grad():
    # Tokens of the clean image first, then of the adversarial one
    # (noted as [B, T, D] by the original author).
    z_clean_clip, z_adv_clip = (
        encode_and_norm(model, img_t) for img_t in (image, image_adv)
    )

    # Similarity of those tokens to every text feature
    # (noted as [B, H, W, N] by the original author).
    sim_map_clean_clip, sim_map_adv_clip = (
        build_similarity_map(tokens, text_features_clip)
        for tokens in (z_clean_clip, z_adv_clip)
    )

# =========================
# 6. ABLATION TABLE (S1 / S2 / FS)
# =========================
# Base variants as (name, S1, S2). Each one appears twice in the table:
# once as is (FS=False) and once as "<name> + FS" (FS=True).
_BASE_VARIANTS = (
    ("CLIP vanilla",        False, False),
    ("FaithShield S1+S2",   True,  True),
    ("FaithShield S1-only", True,  False),
    ("FaithShield S2-only", False, True),
)

# Rows are (variant name, S1, S2, FS), with the plain variant directly
# followed by its "+ FS" counterpart.
ablation_table = [
    (f"{base_name} + FS" if with_fs else base_name, s1_flag, s2_flag, with_fs)
    for base_name, s1_flag, s2_flag in _BASE_VARIANTS
    for with_fs in (False, True)
]

# =========================
# 7. CACHE MODELS / TOKENS / SIM MAPS
# =========================
# One dict per kind of artefact, all sharing the same keys. Each starts with
# the baseline results under "CLIP"; the surgery variants are added later
# under keys of the form "S1<bool>_S2<bool>".
model_cache     = {"CLIP": model}
z_clean_cache   = {"CLIP": z_clean_clip}
z_adv_cache     = {"CLIP": z_adv_clip}
text_feat_cache = {"CLIP": text_features_clip}
sim_clean_cache = {"CLIP": sim_map_clean_clip}
sim_adv_cache   = {"CLIP": sim_map_adv_clip}

# ---- FaithShield (CLIPSurgery) models for all (S1,S2) combos except CLIP ----
# Distinct (S1, S2) pairs with at least one flag set; the all-False pair is
# the baseline that is already cached under "CLIP".
unique_S_pairs = {
    (s1_flag, s2_flag)
    for _name, s1_flag, s2_flag, _fs in ablation_table
    if (s1_flag or s2_flag)
}

for S1, S2 in unique_S_pairs:
    key = f"S1{S1}_S2{S2}"
    print(f"\nLoading CLIPSurgery for key={key}: use_consistent_attention={S1}, use_skip_ffn={S2}")

    # NOTE(review): these two keyword arguments imply `clip` is a modified
    # CLIP Surgery package rather than the stock OpenAI one - confirm.
    surgery_flags = {"use_consistent_attention": S1, "use_skip_ffn": S2}
    model_surg, _ = clip.load(surgery_model_name, device=device, **surgery_flags)
    model_surg.eval()
    model_cache[key] = model_surg

    with torch.no_grad():
        # Image tokens for the clean and the adversarial input.
        z_clean_s = encode_and_norm(model_surg, image)
        z_adv_s = encode_and_norm(model_surg, image_adv)

        # Text features from this model, scaled to unit length.
        raw_text_surg = clip.encode_text_with_prompt_ensemble(model_surg, all_texts, device)
        text_features_surg = raw_text_surg / raw_text_surg.norm(dim=-1, keepdim=True)

        sim_map_clean_s = build_similarity_map(z_clean_s, text_features_surg)
        sim_map_adv_s = build_similarity_map(z_adv_s, text_features_surg)

    # Publish everything computed for this flag combination under one key.
    z_clean_cache[key] = z_clean_s
    z_adv_cache[key] = z_adv_s
    text_feat_cache[key] = text_features_surg
    sim_clean_cache[key] = sim_map_clean_s
    sim_adv_cache[key] = sim_map_adv_s

# =========================
# 8. FEATURE SURGERY MAPS (for all FS=True variants)
# =========================
print("\nComputing Feature Surgery maps...")
# Unlike the other caches these two are keyed by the full variant name.
fs_clean_cache = {}
fs_adv_cache = {}

# Size handed to clip.get_similarity_map: the first two dims of the BGR array.
_target_hw = cv2_img.shape[:2]

with torch.no_grad():
    for variant_name, S1, S2, FS in ablation_table:
        if FS:
            # Tokens / text features come from the matching non-FS model.
            key = f"S1{S1}_S2{S2}" if (S1 or S2) else "CLIP"

            zc = z_clean_cache[key]
            za = z_adv_cache[key]
            tf = text_feat_cache[key]

            sim_FS_clean = clip.clip_feature_surgery(zc, tf)  # [B, T, N] per the original notes
            sim_FS_adv = clip.clip_feature_surgery(za, tf)

            # [:, 1:, :] drops the first token along dim 1 (presumably the
            # class token - TODO confirm) before building the maps.
            fs_clean_cache[variant_name] = clip.get_similarity_map(
                sim_FS_clean[:, 1:, :], _target_hw
            )
            fs_adv_cache[variant_name] = clip.get_similarity_map(
                sim_FS_adv[:, 1:, :], _target_hw
            )

# =========================
# 9. FULL METRICS TABLE (CLEAN + ADV) via evaluate_image_pair
# =========================
print("\nEvaluating all variants (clean+adv metrics)...\n")

rows = []

for variant_name, S1, S2, FS in ablation_table:
    # Model and text features: the baseline lives under "CLIP", every surgery
    # combination under its "S1<bool>_S2<bool>" key.
    key = f"S1{S1}_S2{S2}" if (S1 or S2) else "CLIP"
    model_v = model_cache[key]
    tfeat = text_feat_cache[key]

    # Similarity maps: FS variants use the feature-surgery maps (keyed by
    # variant name), the others the plain maps (keyed like the model).
    if FS:
        sim_clean, sim_adv = fs_clean_cache[variant_name], fs_adv_cache[variant_name]
    else:
        sim_clean, sim_adv = sim_clean_cache[key], sim_adv_cache[key]

    metrics = evaluate_image_pair(
        model_v, image, image_adv, tfeat, target_idx_single, sim_clean, sim_adv
    )

    # Tag the result with the ablation flags and the variant name.
    for column, value in (("S1", S1), ("S2", S2), ("FS", FS), ("Variant", variant_name)):
        metrics[column] = value

    rows.append(metrics)

# One row per variant, indexed by variant name, with the three ablation flags
# moved to the front and the remaining columns kept in their existing order.
_flag_cols = ["S1", "S2", "FS"]
df_full = pd.DataFrame(rows).set_index("Variant")
cols_order = _flag_cols + [col for col in df_full.columns if col not in _flag_cols]
df_full = df_full[cols_order]

print("\n========== FULL CLEAN+ADV METRICS (EVALUATE_IMAGE_PAIR) ==========\n")
print(df_full)
print("\n----------------- Aggregate (mean / std) -----------------\n")
# Only the mean and std rows of the describe() summary are shown.
_summary = df_full.describe()
print(_summary.loc[["mean", "std"]])
print("\n==========================================================\n")

# =========================
# 10. VISUALIZATION EXAMPLES (HEATMAPS)
# =========================
# De-normalise both tensors and convert them to NumPy arrays: squeeze(0) drops
# the leading dim and permute(1, 2, 0) moves the first remaining dim to the end.
# NOTE(review): neither array is read again in the visible part of this file.
orig_img, adv_img = (
    denorm_clip(img_t).squeeze(0).permute(1, 2, 0).cpu().numpy()
    for img_t in (image, image_adv)
)

# Helper: fetch the plain and Feature-Surgery similarity maps of a base variant.
# Only the two variants listed inside are supported; any other name raises.
def get_sim_maps_for_variant(variant_name_base):
    """Return the cached similarity maps for one base variant.

    Args:
        variant_name_base: Either "CLIP vanilla" or "FaithShield S1+S2".

    Returns:
        A 4-tuple ``(sim_clean_base, sim_adv_base, sim_clean_fs, sim_adv_fs)``
        read from ``sim_clean_cache``, ``sim_adv_cache``, ``fs_clean_cache``
        and ``fs_adv_cache`` respectively.

    Raises:
        ValueError: If ``variant_name_base`` is not one of the two supported names.
    """
    supported = (
        # (base variant name, key of its plain maps, name of its "+ FS" variant)
        ("CLIP vanilla", "CLIP", "CLIP vanilla + FS"),
        ("FaithShield S1+S2", "S1True_S2True", "FaithShield S1+S2 + FS"),
    )

    for base_name, cache_key, fs_name in supported:
        if variant_name_base == base_name:
            return (
                sim_clean_cache[cache_key],
                sim_adv_cache[cache_key],
                fs_clean_cache[fs_name],
                fs_adv_cache[fs_name],
            )

    raise ValueError(f"Unknown base variant for visualization: {variant_name_base}")

# Two comparison figures, drawn in this order:
#   Example 1: CLIP vanilla vs CLIP vanilla + FS
#   Example 2: FaithShield S1+S2 vs FaithShield S1+S2 + FS
for _figure_title, _base_variant in (
    ("CLIP vanilla vs CLIP vanilla + FS", "CLIP vanilla"),
    ("FaithShield S1+S2 vs FaithShield S1+S2 + FS", "FaithShield S1+S2"),
):
    # Plain maps and Feature-Surgery maps of the same base variant.
    sim_clean_base, sim_adv_base, sim_clean_fs, sim_adv_fs = get_sim_maps_for_variant(
        _base_variant
    )

    plot_all_together(
        model_name=_figure_title,
        orig_img_t=image,
        adv_img_t=image_adv,
        cv2_img=cv2_img,
        sim_map_clean=sim_clean_base,
        sim_map_adv=sim_adv_base,
        sim_map_clean_fs=sim_clean_fs,
        sim_map_adv_fs=sim_adv_fs,
        all_texts=all_texts,
        target_texts=target_texts,
    )

# Closing summary of what the run produced, one line per item.
print(
    "\nDone. You now have:",
    "- 8 variants (CLIP, FaithShield S1/S2, FS) with S1/S2/FS flags",
    "- A full clean+adv metrics table (df_full)",
    "- Heatmap plots for CLIP vs CLIP+FS and S1+S2 vs S1+S2+FS.\n",
    sep="\n",
)
