import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

from Utils import *
from Adv_attack import *
from Detection import *
from Evaluation import *
from Plots import *

# Constants
BICUBIC = InterpolationMode.BICUBIC

# --- Load and preprocess image & model ---
# Backbone passed to clip.load(); alternatives noted by the author:
# "ViT-L/14", "ViT-B/32".
model_name = "ViT-B/16"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load base model and switch it to inference mode.
model, preprocess = clip.load(model_name, device=device)
model.eval()

# Load surgery model (choose one); alternatives noted by the author:
# "CS-ViT-L/14", "CS-ViT-B/32".
surgery_model_name = "CS-ViT-B/16"

# The second return value (the bundled preprocess) is discarded.
model_surgery, _ = clip.load(surgery_model_name, device=device)
model_surgery.eval()

# Custom preprocessing (overrides CLIP default if needed):
# fixed 224x224 bicubic resize -> tensor -> per-channel normalisation.
_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_NORM_STD = (0.26862954, 0.26130258, 0.27577711)
_preprocess_steps = [
    Resize((224, 224), interpolation=BICUBIC),
    ToTensor(),
    Normalize(_NORM_MEAN, _NORM_STD),
]
custom_preprocess = Compose(_preprocess_steps)

# --- Load image ---

# Open the file inside a context manager so its handle is released as soon as
# the pixels have been read; convert() hands back a separate RGB image that
# stays valid after the source image is closed.
with Image.open("demo.jpg") as _src_img:
    pil_img = _src_img.convert("RGB")

# OpenCV-style (BGR channel order) copy of the same picture.
cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

image = custom_preprocess(pil_img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

# --- Texts ---
# Candidate labels that are turned into text features further down.
# NOTE(review): 'staire' looks like a typo (stairs / staircase?); it is left
# untouched here because changing a label changes the encoded text features.
all_texts = [
    'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
    'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
    'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
    'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', 'potted plant',
    'road', 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track',
    'train', 'tree', 'truck', 'tv monitor', 'wall', 'water', 'window', 'wood', 'staire', 'desk',
    'cards', 'vespa', 'bear', 'banana', 'piano',
]

# Labels of interest and their positions inside all_texts (one index per label).
target_texts = ['bench']
target_idx = [all_texts.index(t) for t in target_texts]

# Target for adversarial attack
target_text = "wall"
target_idx_adv = all_texts.index(target_text)

# Ensure consistent image size
img = pil_img.resize((224, 224))





# --- Step 1: Encode text ---
# Gradients are not needed for the text side, so everything runs under no_grad.
with torch.no_grad():
    _raw_text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device)
    # Scale every feature vector to unit L2 norm along the last axis.
    _text_norms = _raw_text_features.norm(dim=-1, keepdim=True)
    text_features = _raw_text_features / _text_norms
    print("Text feature dim:", text_features.shape)

# --- Step 2: Adversarial attack ---

# Tunable attack settings, kept together so they are easy to adjust.
# The trailing remarks are the author's tuning notes.
_attack_settings = {
    "steps": 300,            # more steps
    "step_size": 3e-2,       # bigger update
    "lambda_pred": 1.0,      # but gets down-weighted inside loss
    "lambda_entropy": 0.5,
    "lambda_margin": 0.5,
    "K_patches": 50,         # more patches pushed
    "l0_ratio": 0.02,        # allow more pixels to move
    "use_margin": True,
    "use_surgery": True,
    "surgery_temp": 2.0,
    "log_every": 10,
}

# NOTE(review): target_idx is the index list built from target_texts, while
# target_idx_adv (the "wall" index) is defined earlier but never passed here.
# Confirm which of the two the attack is supposed to receive.
image_adv, history = generate_adversarial_image(
    image=image,
    model=model,
    text_features=text_features,
    target_idx=target_idx,
    model_surgery=model_surgery,
    **_attack_settings,
)

# Sanity check: the image encoder must return per-token features (not only a
# pooled CLS vector) so that a patch-level similarity map can be built.
# sim_map_clean computed here is recomputed in Step 3 below.
with torch.no_grad():
    z = model.encode_image(image)
    if not has_tokens(z):
        # CLS-only: build a patch map proxy via last attn or Grad-CAM++ (pick one)
        raise RuntimeError("encode_image returned CLS only; enable a tokens-returning forward or implement a patch-map method.")

    # tokens path
    z = z / z.norm(dim=-1, keepdim=True)                  # [1, 197, D] per the author's note
    s = z @ text_features.T                                # [1, 197, N]

    # Every token except the first is treated as a patch and must tile a
    # square grid. round() avoids truncating a float square root, and an
    # explicit raise (unlike assert) still fires under `python -O`.
    _num_patches = z.shape[1] - 1
    Hm = Wm = round(_num_patches ** 0.5)
    if Hm * Wm != _num_patches:
        raise RuntimeError(
            f"Cannot arrange {_num_patches} patch tokens into a square grid."
        )
    sim_map_clean = s[:, 1:, :].reshape(1, Hm, Wm, -1)


# --- Step 3: Build similarity maps ---
# encode_and_norm / build_similarity_map are project helpers; each model is
# run on the clean image first and on the adversarial image second.
with torch.no_grad():
    # Base model
    z_clean, z_adv = (encode_and_norm(model, batch) for batch in (image, image_adv))
    sim_map_clean, sim_map_adv = (
        build_similarity_map(feats, text_features) for feats in (z_clean, z_adv)
    )

    # Surgery model
    z_clean_s, z_adv_s = (
        encode_and_norm(model_surgery, batch) for batch in (image, image_adv)
    )
    sim_map_clean_surgery, sim_map_adv_surgery = (
        build_similarity_map(feats, text_features) for feats in (z_clean_s, z_adv_s)
    )

# --- Step 4: Evaluate ---
# Compare the base model's clean and adversarial embeddings and report the
# entries of the returned metrics mapping.
metrics = evaluate_predictions(z_clean, z_adv, text_features, all_texts)
print(f"Clean prediction: {metrics['clean_pred']}")
print(f"Adversarial prediction: {metrics['adv_pred']}")
print(f"Global cosine similarity: {metrics['cosine_sim']:.3f}")
print(f"Max ΔProb: {metrics['max_delta_prob']:.3f}")


# --- Step 5: De-normalised copies of both images as NumPy arrays ---
def _to_hwc_array(batch):
    """Undo the normalisation of *batch* and return it as an HWC NumPy array.

    The batch axis (axis 0) is squeezed away, the remaining axes are permuted
    from (0, 1, 2) to (1, 2, 0), and the result is moved to the CPU.
    """
    single = denorm_clip(batch).squeeze(0)
    return single.permute(1, 2, 0).cpu().numpy()


orig_img = _to_hwc_array(image)
adv_img = _to_hwc_array(image_adv)

# --- Step 6: Evaluation summary ---
if __name__ == "__main__":
    # Both models are scored against the first target label only.
    _summary_target = all_texts.index(target_texts[0])

    # (model, clean map, adversarial map) for each row of the summary table.
    _summary_inputs = (
        (model, sim_map_clean, sim_map_adv),
        (model_surgery, sim_map_clean_surgery, sim_map_adv_surgery),
    )
    _summary_rows = []
    for _net, _clean_map, _adv_map in _summary_inputs:
        _summary_rows.append(
            evaluate_image_pair(_net, image, image_adv, text_features,
                                _summary_target, _clean_map, _adv_map)
        )
    results_clip, results_surgery = _summary_rows

    df = pd.DataFrame(_summary_rows, index=["CLIP","Proposed framwork"])
    print(df)
    # Mean and standard deviation of every metric column across the two rows.
    print("\nAggregate Results:\n", df.describe().loc[["mean","std"]])

# ---------------------------
# Evaluation (Proposed Framework vs CLIP)
# ---------------------------
if __name__ == "__main__":

    with torch.no_grad():
        # Embeddings and similarity maps from the surgery model
        # (the "proposed framework"), clean image first, adversarial second.
        z_clean_pf, z_adv_pf = (
            encode_and_norm(model_surgery, batch) for batch in (image, image_adv)
        )
        sim_map_clean_pf, sim_map_adv_pf = (
            build_similarity_map(feats, text_features)
            for feats in (z_clean_pf, z_adv_pf)
        )

        # ---- Metrics ----
        # 1) CosSim_CLS: cosine similarity between the first (index 0) token
        #    of the clean and adversarial base-model embeddings.
        cls_clean, cls_adv = z_clean[:, 0, :], z_adv[:, 0, :]
        CosSim_CLS = F.cosine_similarity(cls_clean, cls_adv).item()

        # 2) MaxDeltaProb: largest absolute change in the softmax over the
        #    text labels between the clean and adversarial image.
        text_bank_t = text_features.T
        probs_clean = torch.softmax(cls_clean @ text_bank_t, dim=1)
        probs_adv = torch.softmax(cls_adv @ text_bank_t, dim=1)
        MaxDeltaProb = (probs_adv - probs_clean).abs().max().item()
        # 3) IoU_Topk
        def topk_mask(sim_map, k=50, class_idx=None):
            """Return a boolean [H, W] mask of the k highest-similarity patches.

            Args:
                sim_map: similarity tensor indexed as [batch, H, W, class];
                    only batch element 0 is used.
                k: number of patches to mark. Clamped to the range
                    [0, H * W] so small patch grids (fewer patches than k)
                    no longer make ``topk`` raise.
                class_idx: index along the last axis to rank by. May be an
                    int or a list/tuple of ints, in which case the first
                    entry is used. Defaults to the module-level
                    ``target_idx``.

            Returns:
                Bool tensor of shape [H, W], True at the selected patches.
            """
            if class_idx is None:
                class_idx = target_idx
            # Indexing with a multi-element list would yield H*W*T values and
            # break the final view(); rank by a single class instead (the
            # first target, matching the evaluate_image_pair calls).
            if isinstance(class_idx, (list, tuple)):
                class_idx = class_idx[0]
            class_idx = int(class_idx)

            height, width = sim_map.shape[1], sim_map.shape[2]
            flat = sim_map[0, :, :, class_idx].reshape(-1)

            # topk() raises when k exceeds the number of elements.
            k = max(0, min(int(k), flat.numel()))

            mask = torch.zeros_like(flat, dtype=torch.bool)
            mask[flat.topk(k).indices] = True
            return mask.view(height, width)  # shape [H, W]


        # Top-50 patch masks for the clean and adversarial base-model maps.
        mask_clean, mask_adv = (
            topk_mask(patch_map, k=50) for patch_map in (sim_map_clean, sim_map_adv)
        )

        # Intersection-over-union of the two masks; 0.0 when both are empty.
        overlap = (mask_clean & mask_adv).sum().item()
        combined = (mask_clean | mask_adv).sum().item()
        if combined > 0:
            IoU_Topk = overlap / combined
        else:
            IoU_Topk = 0.0
        print("IoU_Topk:", IoU_Topk)


    # ---- Print results ----
    # Scalar metrics computed above for the base (CLIP) model.
    print("\n=== Metrics (CLIP) ===")
    print(f"CosSim_CLS:   {CosSim_CLS:.3f}")
    print(f"MaxDeltaProb: {MaxDeltaProb:.3f}")
    print(f"IoU_Topk:     {IoU_Topk:.3f}")

    # ---- Plot for comparison ----
    # Keyword arguments for the project plotting helper: the image tensors,
    # the base-model maps and the surgery-model ("_fs") maps.
    _plot_kwargs = {
        "orig_img_t": image,
        "adv_img_t": image_adv,
        "cv2_img": cv2_img,
        "sim_map_clean": sim_map_clean,
        "sim_map_adv": sim_map_adv,
        "sim_map_clean_fs": sim_map_clean_pf,
        "sim_map_adv_fs": sim_map_adv_pf,
        "all_texts": all_texts,
        "target_texts": target_texts,
    }
    plot_all_together(model_name, **_plot_kwargs)
