import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

from Utils import *
from Adv_attack import *
from Detection import *
from Evaluation import *
from Plots import *

CONFIG = {
    # Dataset the sample image is taken from. The class in parentheses is
    # presumably the main object of that dataset's default sample image:
    # Flickr30K (dog), COCO (cat), ImageNet (bird), Demo (bench).
    "dataset": "COCO",
    "base_model": "ViT-B/16",
    "surgery_model": "CS-ViT-B/16",
    "target_text_clean": ["cat"],     # label(s) from all_texts used as the clean target class
    "target_text_adv": "ground",      # label from all_texts the adversarial attack targets
    "use_custom_preprocess": False    # True: fixed 224x224 resize; False: the preprocess from clip.load
}

DATASET_ROOTS = {
    "Flickr30K": "F:/All codes/CLIP_Surgery/flick30Images/flickr30k_images/",
    "COCO":      "F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/",
    "ImageNet":  "F:/All codes/CLIP_Surgery/image-net-val/",
    "Demo":      "F:/All codes/CLIP_Surgery/codes_paper/"
}

DATASET_DEFAULT_SAMPLE = {
    "Flickr30K": "459814265.jpg",
    "COCO":      "000000003599.jpg",
    "ImageNet":  "ILSVRC2012_val_00002023.JPEG",
    "Demo":      "demo.jpg"
}

# Build the sample image path. Both lookup tables are checked so that a
# dataset configured in only one of them fails with a clear message instead
# of a bare KeyError from the second lookup.
if (CONFIG["dataset"] not in DATASET_ROOTS
        or CONFIG["dataset"] not in DATASET_DEFAULT_SAMPLE):
    raise ValueError(f"Dataset '{CONFIG['dataset']}' not configured!")

IMAGE_PATH = os.path.join(
    DATASET_ROOTS[CONFIG["dataset"]],
    DATASET_DEFAULT_SAMPLE[CONFIG["dataset"]]
)


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

def load_clip(model_name):
    """Load a CLIP checkpoint by name and switch it to evaluation mode.

    Args:
        model_name: checkpoint name accepted by ``clip.load`` (e.g. "ViT-B/16").

    Returns:
        The ``(model, preprocess)`` pair from ``clip.load``, with the model
        already put into eval mode.
    """
    print(f"Loading CLIP model: {model_name}")
    loaded_model, transform = clip.load(model_name, device=device)
    loaded_model.eval()
    return loaded_model, transform
# Load the plain CLIP backbone and the CLIP Surgery variant. Only the base
# model's preprocess is kept: inputs for both models come from preprocess_image.
model_name = CONFIG["base_model"]
model, base_preprocess = load_clip(model_name)
model_surgery, _ = load_clip(CONFIG["surgery_model"])

# Alternative to the preprocess returned by clip.load: resize directly to
# exactly 224x224 (aspect ratio is not preserved), convert to a tensor and
# normalize with CLIP's image mean / std.
BICUBIC = InterpolationMode.BICUBIC

_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

custom_preprocess = Compose(
    [
        Resize((224, 224), interpolation=BICUBIC),
        ToTensor(),
        Normalize(_CLIP_MEAN, _CLIP_STD),
    ]
)
# Candidate labels that are encoded as CLIP text prompts. Order matters:
# labels are looked up by position with all_texts.index(...). The last five
# entries are function words and whole sentences rather than object classes.
# NOTE(review): "staire" looks like a typo for "stair(s)"; it is left as-is
# because changing the string would change its text embedding.
all_texts = [
    "airplane", "bag", "bed", "bedclothes", "bench", "bicycle", "bird", "boat",
    "book", "bottle", "building", "bus", "cabinet", "car", "cat", "ceiling",
    "chair", "cloth", "computer", "cow", "cup", "curtain", "dog", "door",
    "fence", "floor", "flower", "food", "grass", "ground", "horse", "keyboard",
    "light", "motorbike", "mountain", "mouse", "person", "plate", "platform",
    "potted plant", "road", "rock", "sheep", "shelves", "sidewalk", "sign",
    "sky", "snow", "sofa", "table", "track", "train", "tree", "truck",
    "tv monitor", "wall", "water", "window", "wood", "staire", "desk",
    "cards", "vespa", "bear", "banana", "piano",
    # function words and full sentences
    "this", "is", "a", "this is a bench", "this is a wall",
]
def preprocess_image(pil_img):
    """Turn a PIL image into a single-image CLIP input batch on ``device``.

    The transform is ``custom_preprocess`` when CONFIG["use_custom_preprocess"]
    is true, otherwise the base model's preprocess from ``clip.load``.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    transform = custom_preprocess if CONFIG["use_custom_preprocess"] else base_preprocess
    batch = transform(pil_img).unsqueeze(0)
    return batch.to(device)

# Load the sample image and derive three views of it:
# the RGB PIL image, the preprocessed input batch, and a BGR array for OpenCV.
with Image.open(IMAGE_PATH) as src:
    # convert() returns an independent image, so the file can be closed here.
    pil_img = src.convert("RGB")
image = preprocess_image(pil_img)
cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

# Label indices into all_texts.
target_texts = CONFIG["target_text_clean"]
target_idx = [all_texts.index(t) for t in target_texts]        # clean target class(es)
target_idx_adv = all_texts.index(CONFIG["target_text_adv"])    # adversarial target class
target_text = target_idx_adv  # legacy alias; same value as target_idx_adv


# --- Step 1: Encode every label in all_texts, then L2-normalize along the last dim ---
with torch.no_grad():
    raw_text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device)
    row_norms = raw_text_features.norm(dim=-1, keepdim=True)
    text_features = raw_text_features / row_norms
    print("Text feature dim:", text_features.shape)

# --- Step 2: Adversarial attack ---
# Perturb `image` so that its similarity map is pushed towards the label
# CONFIG["target_text_adv"] (index target_idx_adv in all_texts).
image_adv, history = generate_adversarial_image(
    image=image,
    model=model,
    text_features=text_features,
    target_idx=target_idx_adv,   # attack target: CONFIG["target_text_adv"] ("ground" by default)
    # None: no extra class is constrained; the attack only tries to keep the
    # clean CLIP output. Per the original author, passing an index here can
    # also steer the map towards that class (semantics live in Adv_attack).
    preserve_idx=None,
    steps=400,
    step_size=0.05,
    lambda_pred=0.5,
    lambda_entropy=0.5,
    lambda_margin=0.5,
    K_patches=100,
    l0_ratio=0.05,
    use_margin=True,
    use_surgery=False,
    model_surgery=model_surgery,
    surgery_temp=2.0,
    log_every=10)


# --- Sanity check: Step 3 needs per-patch tokens from encode_image ---
# The clean similarity map is built in Step 3 via build_similarity_map, so
# only the token layout of the image embedding is validated here.
with torch.no_grad():
    z = model.encode_image(image)
    if not has_tokens(z):
        # CLS-only: a patch map would need a proxy (last attention, Grad-CAM++, ...).
        raise RuntimeError("encode_image returned CLS only; enable a tokens-returning forward or implement a patch-map method.")
    # Expected layout: [1, 1 + H*W, D] -- one CLS token followed by a square patch grid.
    n_patch_tokens = z.shape[1] - 1
    Hm = Wm = int(n_patch_tokens ** 0.5)
    if Hm * Wm != n_patch_tokens:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise RuntimeError(f"Expected a square patch grid, got {n_patch_tokens} patch tokens.")


# --- Step 3: Build similarity maps ---
# Encode the clean and adversarial images with both backbones and turn each
# embedding into per-label similarity maps against text_features.
with torch.no_grad():
    # plain CLIP
    z_clean, z_adv = encode_and_norm(model, image), encode_and_norm(model, image_adv)
    sim_map_clean, sim_map_adv = (
        build_similarity_map(z_clean, text_features),
        build_similarity_map(z_adv, text_features),
    )

    # CLIP Surgery (the "proposed framework")
    z_clean_s, z_adv_s = (
        encode_and_norm(model_surgery, image),
        encode_and_norm(model_surgery, image_adv),
    )
    sim_map_clean_surgery, sim_map_adv_surgery = (
        build_similarity_map(z_clean_s, text_features),
        build_similarity_map(z_adv_s, text_features),
    )

# --- Step 4: Evaluate ---
metrics = evaluate_predictions(z_clean, z_adv, text_features, all_texts)
print(f"Clean prediction: {metrics['clean_pred']}")
print(f"Adversarial prediction: {metrics['adv_pred']}")
print(f"Global cosine similarity: {metrics['cosine_sim']:.3f}")
print(f"Max ΔProb: {metrics['max_delta_prob']:.3f}")

# --- Step 5: Images back in pixel space for display ---
# Produces (H, W, C) numpy arrays, assuming denorm_clip keeps the
# [1, C, H, W] layout of the preprocessed batch. detach() guards against
# image_adv still carrying autograd history from the attack, since .numpy()
# refuses tensors that require grad.
orig_img = denorm_clip(image.detach()).squeeze(0).permute(1, 2, 0).cpu().numpy()
adv_img = denorm_clip(image_adv.detach()).squeeze(0).permute(1, 2, 0).cpu().numpy()

# --- Step 6: Evaluation summary ---
if __name__ == "__main__":
    # Score plain CLIP and CLIP Surgery on the same clean/adversarial pair,
    # using the first clean target label as the class of interest.
    eval_class_idx = target_idx[0]   # == all_texts.index(target_texts[0])

    results_clip = evaluate_image_pair(model, image, image_adv, text_features,
                                       eval_class_idx,
                                       sim_map_clean, sim_map_adv)

    results_surgery = evaluate_image_pair(model_surgery, image, image_adv, text_features,
                                          eval_class_idx,
                                          sim_map_clean_surgery, sim_map_adv_surgery)

    df = pd.DataFrame([results_clip, results_surgery], index=["CLIP", "Proposed framework"])
    print(df)
    # NOTE: df has only the two model rows, so mean/std here are taken across
    # the two models, not across images.
    print("\nAggregate Results:\n", df.describe().loc[["mean", "std"]])

# ---------------------------
# Evaluation (Proposed Framework vs CLIP)
# ---------------------------
if __name__ == "__main__":

    def topk_mask(sim_map, k=50, class_idx=None):
        """Return a boolean [H, W] mask marking the k most similar patches.

        Args:
            sim_map: similarity map indexed as ``sim_map[0, h, w, label]``
                (assumed layout [1, H, W, N] -- TODO confirm against
                build_similarity_map).
            k: number of patches to keep; clamped to H*W.
            class_idx: label column to rank; defaults to the first clean
                target, ``target_idx[0]``.
        """
        if class_idx is None:
            class_idx = target_idx[0]
        height, width = sim_map.shape[1], sim_map.shape[2]
        # A single integer index keeps exactly H*W scores. Indexing with the
        # whole target_idx list would break view(H, W) for more than one target.
        flat = sim_map[0, :, :, class_idx].reshape(-1)
        k = min(k, flat.numel())
        mask = torch.zeros_like(flat, dtype=torch.bool)
        mask[flat.topk(k).indices] = True
        return mask.view(height, width)

    with torch.no_grad():
        # The proposed framework is the CLIP Surgery model. Its embeddings and
        # maps for this image pair were already computed in Step 3, so reuse
        # them instead of running two more forward passes.
        z_clean_pf, z_adv_pf = z_clean_s, z_adv_s
        sim_map_clean_pf, sim_map_adv_pf = sim_map_clean_surgery, sim_map_adv_surgery

        # ---- Metrics (plain CLIP) ----
        # 1) CosSim_CLS: cosine similarity of the clean vs adversarial CLS token
        zc_cls = z_clean[:, 0, :]
        za_cls = z_adv[:, 0, :]
        CosSim_CLS = F.cosine_similarity(zc_cls, za_cls).item()

        # 2) MaxDeltaProb: largest absolute change of any label probability.
        # NOTE(review): no logit scale is applied before softmax, so these
        # probabilities are much flatter than CLIP's usual scaled logits -- confirm intended.
        logits_c = zc_cls @ text_features.T
        logits_a = za_cls @ text_features.T
        probs_c = torch.softmax(logits_c, dim=1)
        probs_a = torch.softmax(logits_a, dim=1)
        MaxDeltaProb = (probs_a - probs_c).abs().max().item()

        # 3) IoU_Topk: overlap of the 49 most target-similar patches before vs after the attack
        mask_clean = topk_mask(sim_map_clean, k=49)
        mask_adv = topk_mask(sim_map_adv, k=49)
        inter = (mask_clean & mask_adv).sum().item()
        union = (mask_clean | mask_adv).sum().item()
        IoU_Topk = inter / union if union > 0 else 0.0

    # ---- Print results ----
    print("\n=== Metrics (CLIP) ===")
    print(f"CosSim_CLS:   {CosSim_CLS:.3f}")
    print(f"MaxDeltaProb: {MaxDeltaProb:.3f}")
    print(f"IoU_Topk:     {IoU_Topk:.3f}")

    # ---- Plot for comparison (CLIP vs proposed framework) ----
    plot_all_together(
        model_name,
        orig_img_t=image, adv_img_t=image_adv, cv2_img=cv2_img,
        sim_map_clean=sim_map_clean, sim_map_adv=sim_map_adv,
        sim_map_clean_fs=sim_map_clean_pf, sim_map_adv_fs=sim_map_adv_pf,
        all_texts=all_texts, target_texts=target_texts
    )
    

# XAI transferability plot. It consumes sim_map_clean_pf / sim_map_adv_pf,
# which are only defined inside the __main__ sections above, so it runs
# under the same guard. At module level, importing this file would raise NameError.
if __name__ == "__main__":
    plot_xai_transferability_fs_extended(
        model=model,
        model_surgery=model_surgery,
        model_name=model_name,
        cv2_img=cv2_img,
        image_clean=image,
        image_adv=image_adv,
        sim_map_clean=sim_map_clean,
        sim_map_adv=sim_map_adv,
        sim_map_clean_fs=sim_map_clean_pf,
        sim_map_adv_fs=sim_map_adv_pf,
        text_features=text_features,
        all_texts=all_texts,
        target_texts=target_texts,
    )
