import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

from Utils import *
from Adv_attack import *
from Detection import *
from Evaluation import *
from Plots import *

# Constants
BICUBIC = InterpolationMode.BICUBIC

# --- Load and preprocess image & model ---
# Base CLIP backbone. Alternatives: "ViT-L/14", "ViT-B/32", "ViT-B/16".
model_name = "RN50"
# CLIP Surgery variant to load. Alternatives mirror the base options:
# "CS-ViT-L/14", "CS-ViT-B/32", "CS-ViT-B/16".
surgery_model_name = "CS-RN50"

device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_eval_clip(name):
    """Load CLIP variant *name* onto ``device`` and put it in eval mode.

    Returns the ``(model, preprocess)`` pair produced by ``clip.load``.
    """
    net, transform = clip.load(name, device=device)
    net.eval()
    return net, transform


# Base model. Its preprocess transform is kept, but the script feeds the
# image through custom_preprocess instead.
model, preprocess = _load_eval_clip(model_name)

# Surgery model; its preprocess transform is discarded.
model_surgery, _ = _load_eval_clip(surgery_model_name)

# Custom preprocessing used in place of CLIP's default transform. It does a
# direct bicubic resize to 224x224 (aspect ratio is NOT preserved, no center
# crop), converts to a tensor, and applies CLIP's per-channel normalisation.
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

custom_preprocess = Compose(
    [
        Resize((224, 224), interpolation=BICUBIC),
        ToTensor(),
        Normalize(_CLIP_MEAN, _CLIP_STD),
    ]
)

# --- Load image ---
pil_img = Image.open("demo.jpg").convert("RGB")

# OpenCV (BGR channel order) copy of the same image, handed to the plotting code.
rgb_array = np.array(pil_img)
cv2_img = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)

# Model input: a batch of one preprocessed image on the target device.
image = custom_preprocess(pil_img)[None].to(device)  # Shape: [1, 3, 224, 224]

# --- Texts ---
# Text vocabulary. A position in this list is the class index used for
# text_features rows and for the last axis of the similarity maps.
all_texts = [
    'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle',
    'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow',
    'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
    'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', 'potted plant',
    'road', 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track',
    'train', 'tree', 'truck', 'tv monitor', 'wall', 'water', 'window', 'wood', 'staire', 'desk',
    'cards', 'vespa', 'bear', 'banana', 'piano', 'this', 'is', 'a', 'this is a bench', 'this is a wall',
]

# Class(es) whose localisation is compared between clean and adversarial input.
target_texts = ['bench']
target_idx = list(map(all_texts.index, target_texts))

# Class the adversarial attack is steered toward.
target_text = "wall"
target_idx_adv = all_texts.index(target_text)



# # # pil_img = Image.open("demo.jpg")
# pil_img = Image.open("F:/All codes/CLIP_Surgery/flick30Images/flickr30k_images/459814265.jpg") #dog
# # # pil_img = Image.open("F:/All codes/CLIP_Surgery/image-net-val/ILSVRC2012_val_00002023.JPEG") #bird
# # # pil_img = Image.open("F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/000000002115.jpg")
# # # pil_img = Image.open("F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/000000002069.jpg")
# # # pil_img = Image.open("F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/000000002651.jpg")

# # pil_img = Image.open("F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/000000003599.jpg") #cat



# cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
# image = preprocess(pil_img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

# # --- Texts ---
# all_texts = ['airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 
#              'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 
#              'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 'keyboard', 'light', 'motorbike', 
#              'mountain', 'mouse', 'person', 'plate', 'platform', 'potted plant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', 
#              'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', 'tree', 'truck', 'tv monitor', 'wall', 'water', 'window', 
#              'wood', 'staire', 'wall', 'desk','cards', 'vespa', 'bear','banana','piano']

# target_texts = ['dog']
# target_idx = [all_texts.index(t) for t in target_texts]
# #  target attack
# target_text = "ground"
# target_idx_adv = all_texts.index(target_text)



# PIL copy of the input at model resolution.
# NOTE(review): `img` appears unused in the rest of this script; confirm before removing.
img = pil_img.resize((224, 224))

# --- Step 1: Encode text ---
# Text embeddings for every entry of all_texts, L2-normalised along the last dim.
with torch.no_grad():
    raw_text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device)
    text_features = raw_text_features / raw_text_features.norm(dim=-1, keepdim=True)
    print("Text feature dim:", text_features.shape)

# --- Step 2: Adversarial attack ---
# Targeted attack pushing the similarity map toward all_texts[target_idx_adv].
# preserve_idx=None: per the original author's note, the attack then only
# preserves the clean CLIP output; supplying an index would additionally let
# the similarity map be pushed toward another chosen class.
attack_cfg = dict(
    steps=400,
    step_size=0.05,
    lambda_pred=0.5,
    lambda_entropy=0.5,
    lambda_margin=0.5,
    K_patches=100,
    l0_ratio=0.05,
    use_margin=True,
    use_surgery=False,
    model_surgery=model_surgery,
    surgery_temp=2.0,
    log_every=10,
)
image_adv, history = generate_adversarial_image(
    image=image,
    model=model,
    text_features=text_features,
    target_idx=target_idx_adv,
    preserve_idx=None,
    **attack_cfg,
)


# Sanity check: the similarity maps built in Step 3 need per-token image
# features (a leading token followed by a square grid of patch tokens), not a
# single pooled embedding. The patch similarity map itself is NOT built here:
# any map computed at this point would be overwritten by Step 3.
with torch.no_grad():
    z = model.encode_image(image)

if not has_tokens(z):
    # CLS-only output: a patch-map proxy (last-layer attention or Grad-CAM++)
    # would have to be implemented instead.
    raise RuntimeError("encode_image returned CLS only; enable a tokens-returning forward or implement a patch-map method.")

# Patch tokens (everything after token 0) must form a square grid.
# Explicit raise rather than `assert`, which is stripped under `python -O`.
Hm = Wm = int((z.shape[1] - 1) ** 0.5)
if 1 + Hm * Wm != z.shape[1]:
    raise RuntimeError(
        f"expected 1 + a square number of tokens, got {z.shape[1]}"
    )


# --- Step 3: Build similarity maps ---
# For each model (base CLIP first, then the surgery model): normalised
# features of the clean and adversarial images, and their similarity maps
# against the shared text_features.
with torch.no_grad():
    z_clean, z_adv = [encode_and_norm(model, x) for x in (image, image_adv)]
    sim_map_clean, sim_map_adv = [
        build_similarity_map(feats, text_features) for feats in (z_clean, z_adv)
    ]

    z_clean_s, z_adv_s = [encode_and_norm(model_surgery, x) for x in (image, image_adv)]
    sim_map_clean_surgery, sim_map_adv_surgery = [
        build_similarity_map(feats, text_features) for feats in (z_clean_s, z_adv_s)
    ]

# --- Step 4: Evaluate ---
# Global (base-model) prediction comparison between clean and adversarial input.
metrics = evaluate_predictions(z_clean, z_adv, text_features, all_texts)
print(f"Clean prediction: {metrics['clean_pred']}")
print(f"Adversarial prediction: {metrics['adv_pred']}")
print(f"Global cosine similarity: {metrics['cosine_sim']:.3f}")
print(f"Max ΔProb: {metrics['max_delta_prob']:.3f}")


# Drop the batch dim and reorder [3, H, W] -> [H, W, 3] numpy arrays after
# denorm_clip (presumably undoes the CLIP normalisation; confirm in Utils).
# NOTE(review): orig_img / adv_img appear unused later in this script; confirm.
orig_img, adv_img = (
    denorm_clip(t).squeeze(0).permute(1, 2, 0).cpu().numpy()
    for t in (image, image_adv)
)

# --- Step 5: Evaluation summary ---
if __name__ == "__main__":
    # Score the same clean/adversarial pair with each model and its own
    # similarity maps, for the first clean target class (target_texts[0]).
    eval_class_idx = target_idx[0]  # == all_texts.index(target_texts[0])

    results_clip = evaluate_image_pair(model, image, image_adv, text_features,
                                       eval_class_idx,
                                       sim_map_clean, sim_map_adv)

    results_surgery = evaluate_image_pair(model_surgery, image, image_adv, text_features,
                                          eval_class_idx,
                                          sim_map_clean_surgery, sim_map_adv_surgery)

    df = pd.DataFrame([results_clip, results_surgery], index=["CLIP", "Proposed framework"])
    print(df)
    # NOTE(review): with only these two rows, mean/std aggregate across the two
    # methods rather than across images; confirm this is the intended summary.
    print("\nAggregate Results:\n", df.describe().loc[["mean", "std"]])

# ---------------------------
# Evaluation (Proposed Framework vs CLIP)
# ---------------------------
if __name__ == "__main__":

    def topk_mask(sim_map, k=50, class_idx=None):
        """Return a boolean [H, W] mask of the k highest-similarity locations.

        sim_map is indexed as [batch, H, W, class]; only batch 0 is used.
        class_idx defaults to the script-level target_idx. When it selects
        several classes, their maps are averaged first; for a single class the
        result is identical to using that class's map directly. k is clamped
        to H*W so topk cannot fail on small maps.
        """
        if class_idx is None:
            class_idx = target_idx
        scores = sim_map[0][..., class_idx]  # [H, W] for an int, [H, W, C] for a list
        if scores.dim() == 3:
            scores = scores.mean(dim=-1)
        flat = scores.reshape(-1)
        k = min(k, flat.numel())

        mask = torch.zeros_like(flat, dtype=torch.bool)
        mask[flat.topk(k).indices] = True
        return mask.view(sim_map.shape[1], sim_map.shape[2])  # shape [H, W]

    with torch.no_grad():
        # Proposed-framework (surgery) features and maps. These are the same
        # encode_and_norm / build_similarity_map calls on model_surgery that
        # Step 3 already made, so reuse them instead of two more forward passes.
        z_clean_pf, z_adv_pf = z_clean_s, z_adv_s
        sim_map_clean_pf = sim_map_clean_surgery
        sim_map_adv_pf = sim_map_adv_surgery

        # ---- Metrics (base CLIP) ----
        # 1) CosSim_CLS: cosine similarity of token 0 between clean and adversarial.
        zc_cls = z_clean[:, 0, :]
        za_cls = z_adv[:, 0, :]
        CosSim_CLS = F.cosine_similarity(zc_cls, za_cls).item()

        # 2) MaxDeltaProb: largest absolute change of any class probability.
        # CLIP's class probabilities use its learned temperature
        # (logit_scale.exp()); a softmax over raw cosine similarities in [-1, 1]
        # is nearly uniform over the classes, which makes the delta meaningless.
        logit_scale = model.logit_scale.exp()
        logits_c = logit_scale * (zc_cls @ text_features.T)
        logits_a = logit_scale * (za_cls @ text_features.T)
        probs_c, probs_a = torch.softmax(logits_c, dim=1), torch.softmax(logits_a, dim=1)
        MaxDeltaProb = (probs_a - probs_c).abs().max().item()

        # 3) IoU_Topk: overlap of the top-k locations for the clean target class.
        # NOTE(review): if the maps are at patch resolution (e.g. 7x7 = 49
        # tokens), k=49 selects every location and IoU is trivially 1.0;
        # confirm the map resolution from build_similarity_map.
        mask_clean = topk_mask(sim_map_clean, k=49)
        mask_adv = topk_mask(sim_map_adv, k=49)

        inter = (mask_clean & mask_adv).sum().item()
        union = (mask_clean | mask_adv).sum().item()
        IoU_Topk = inter / union if union > 0 else 0.0
        print("IoU_Topk:", IoU_Topk)

    # ---- Print results ----
    print("\n=== Metrics (CLIP) ===")
    print(f"CosSim_CLS:   {CosSim_CLS:.3f}")
    print(f"MaxDeltaProb: {MaxDeltaProb:.3f}")
    print(f"IoU_Topk:     {IoU_Topk:.3f}")

    # ---- Plot for comparison ----
    plot_all_together(
        model_name,
        orig_img_t=image, adv_img_t=image_adv, cv2_img=cv2_img,
        sim_map_clean=sim_map_clean, sim_map_adv=sim_map_adv,
        sim_map_clean_fs=sim_map_clean_pf, sim_map_adv_fs=sim_map_adv_pf,
        all_texts=all_texts, target_texts=target_texts
    )
