import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode

from Utils import *
from Adv_attack import *
from Detection import *
from Evaluation import *
from Plots import *

# ---------------------------------------------------------------------------
# Script setup. Everything below runs at import time: it loads the CLIP
# models, the demo image and the text vocabulary, and defines the
# module-level globals (model, device, image, cv2_img, all_texts,
# target_idx_adv, ...) that the ablation run at the bottom of this file uses.
# ---------------------------------------------------------------------------

# Constants
BICUBIC = InterpolationMode.BICUBIC

# --- Load and preprocess image & model ---
# model_name = "ViT-L/14"
# Alternatives:
# model_name = "ViT-B/32"
model_name = "ViT-B/16"
# model_name = "RN50"

# Use the GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load base model
# NOTE(review): `preprocess` is only referenced in commented-out code below;
# the active pipeline uses `custom_preprocess` instead.
model, preprocess = clip.load(model_name, device=device)
model.eval()

# Load surgery model (choose one)
# surgery_model_name = "CS-ViT-L/14"
# surgery_model_name = "CS-ViT-B/32"
surgery_model_name = "CS-ViT-B/16"
# surgery_model_name = "CS-RN50"
# NOTE(review): `model_surgery` is not referenced anywhere else in this file;
# loading it costs time and (GPU) memory — confirm whether it is still needed.
model_surgery, _ = clip.load(surgery_model_name, device=device)
model_surgery.eval()

# Custom preprocessing used instead of the `preprocess` returned by clip.load.
# Resizes straight to 224x224 (aspect ratio is not preserved, no crop), then
# normalises with what appear to be CLIP's standard RGB mean/std.
custom_preprocess = Compose([
    Resize((224, 224), interpolation=BICUBIC),
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073),
              (0.26862954, 0.26130258, 0.27577711))
])

# --- Load image ---
# The path is relative to the current working directory.

pil_img = Image.open("demo.jpg").convert("RGB")
# BGR copy at the original resolution; used as the heatmap overlay background.
cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

image = custom_preprocess(pil_img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

# --- Texts ---
# Vocabulary encoded into text features. Order matters: indices into this
# list select the similarity-map channel (e.g. sim_map[0, :, :, idx]).
# NOTE(review): 'staire' looks like a typo for 'stairs'; left unchanged because
# editing it changes the encoded text features.
all_texts = [
    'airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 
    'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', 
    'cup', 'curtain', 'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 
    'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person', 'plate', 'platform', 'potted plant', 
    'road', 'rock', 'sheep', 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 'track', 
    'train', 'tree', 'truck', 'tv monitor', 'wall', 'water', 'window', 'wood', 'staire', 'desk', 
    'cards', 'vespa', 'bear', 'banana', 'piano', 'this', 'is', 'a', 'this is a bench', 'this is a wall']

# Indices of the "true" object(s) in the image.
# NOTE(review): `target_idx` is not passed to the ablation run below, which
# uses `target_idx_adv` for both the attack and the heatmaps/IoU.
target_texts = ['bench']
target_idx = [all_texts.index(t) for t in target_texts]

# Target for adversarial attack (a single int index into all_texts).
target_text = "wall"
target_idx_adv = all_texts.index(target_text)



# # # pil_img = Image.open("demo.jpg")
# pil_img = Image.open("F:/All codes/CLIP_Surgery/flick30Images/flickr30k_images/459814265.jpg") #dog
# # # pil_img = Image.open("F:/All codes/CLIP_Surgery/image-net-val/ILSVRC2012_val_00002023.JPEG") #bird
# # # pil_img = Image.open("F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/000000002115.jpg")
# # # pil_img = Image.open("F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/000000002069.jpg")
# # # pil_img = Image.open("F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/000000002651.jpg")

# # pil_img = Image.open("F:/All codes/CLIP_Surgery/coco-unlabeled2017/unlabeled2017/000000003599.jpg") #cat



# cv2_img = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
# image = preprocess(pil_img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

# # --- Texts ---
# all_texts = ['airplane', 'bag', 'bed', 'bedclothes', 'bench', 'bicycle', 'bird', 'boat', 'book', 'bottle', 
#              'building', 'bus', 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 'computer', 'cow', 'cup', 'curtain', 
#              'dog', 'door', 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 'keyboard', 'light', 'motorbike', 
#              'mountain', 'mouse', 'person', 'plate', 'platform', 'potted plant', 'road', 'rock', 'sheep', 'shelves', 'sidewalk', 
#              'sign', 'sky', 'snow', 'sofa', 'table', 'track', 'train', 'tree', 'truck', 'tv monitor', 'wall', 'water', 'window', 
#              'wood', 'staire', 'wall', 'desk','cards', 'vespa', 'bear','banana','piano']

# target_texts = ['dog']
# target_idx = [all_texts.index(t) for t in target_texts]
# #  target attack
# target_text = "ground"
# target_idx_adv = all_texts.index(target_text)



# 224x224 PIL copy of the input image.
# NOTE(review): `img` is not referenced anywhere else in this file.
img = pil_img.resize((224, 224))

# --- Step 1: Encode text ---
# Text embeddings for every entry of `all_texts` (via the prompt-ensemble
# encoder), rescaled to unit L2 norm along the last dimension. No autograd
# graph is recorded.
with torch.no_grad():
    text_features = clip.encode_text_with_prompt_ensemble(model, all_texts, device)
    text_features = text_features.div(text_features.norm(dim=-1, keepdim=True))
print("Text feature dim:", text_features.shape)

###########################################################################################################
#                               ABLATION MODULE FOR HEATMAP + METRICS
###########################################################################################################

def compute_sim_maps(model, image, text_features):
    """Encode ``image`` and build its similarity map against ``text_features``.

    Both steps run under ``torch.no_grad()``, so no autograd graph is kept.

    Args:
        model: Model passed straight to ``encode_and_norm``.
        image: Preprocessed image batch (``[1, 3, 224, 224]`` in this script).
        text_features: Text embeddings forwarded to ``build_similarity_map``.

    Returns:
        tuple: ``(z, sim_map)`` where ``z = encode_and_norm(model, image)`` and
        ``sim_map = build_similarity_map(z, text_features)``. Callers in this
        file index ``z[:, 0, :]`` and ``sim_map[0, :, :, idx]``, so ``z`` is
        presumably ``[B, tokens, D]`` and ``sim_map`` ``[B, H, W, n_texts]``
        — TODO confirm against the helpers' definitions.
    """
    with torch.no_grad():
        z = encode_and_norm(model, image)
        sim_map = build_similarity_map(z, text_features)
    return z, sim_map


def compute_IoU(sim_clean, sim_adv, target_idx, topk=50):
    """IoU of the top-k highest-scoring patches of two similarity maps.

    Both maps are indexed as ``sim[0, :, :, target_idx]`` (first batch item,
    one text channel) and flattened. The positions of the ``topk`` largest
    values in each map are then compared as sets.

    Args:
        sim_clean: Clean similarity map (torch tensor) indexable as
            ``[0, :, :, target_idx]``.
        sim_adv: Adversarial similarity map with the same layout and size.
        target_idx: Index of the text channel to compare.
        topk: Number of top patches taken from each map. Must be
            non-negative. Values larger than the number of map cells are
            clamped to that number instead of raising.

    Returns:
        float: ``|top_clean ∩ top_adv| / |top_clean ∪ top_adv|`` in ``[0, 1]``,
        or ``0.0`` when the union is empty (``topk == 0``).
    """
    flat_c = sim_clean[0, :, :, target_idx].reshape(-1)
    flat_a = sim_adv[0, :, :, target_idx].reshape(-1)

    # Tensor.topk raises when k exceeds the number of elements (e.g. the
    # default topk=50 on a map with fewer than 50 cells), so clamp k.
    k = min(topk, flat_c.numel(), flat_a.numel())

    mask_c = torch.zeros_like(flat_c, dtype=torch.bool)
    mask_a = torch.zeros_like(flat_a, dtype=torch.bool)
    mask_c[flat_c.topk(k).indices] = True
    mask_a[flat_a.topk(k).indices] = True

    inter = (mask_c & mask_a).sum().item()
    union = (mask_c | mask_a).sum().item()
    return inter / union if union > 0 else 0.0


def plot_ablation_heatmaps(
    name, cv2_img,
    sim_c, sim_a,
    H=224, W=224
    ):
    """One-row heatmap visualisation (clean / adversarial / diff) for one ablation.

    Each map is min-max normalised to [0, 1], resized to ``cv2_img``'s own
    height and width, colour-mapped with ``cv2.COLORMAP_JET`` and blended over
    the image as ``0.4 * image + 0.6 * heatmap``. The result is converted from
    BGR to RGB for matplotlib. The figure is displayed with ``plt.show()``,
    and nothing is returned.

    Args:
        name: Label prefixed to each subplot title.
        cv2_img: Background image in BGR channel order with values in 0..255
            (this script builds it with ``cv2.cvtColor(..., COLOR_RGB2BGR)``).
        sim_c: 2-D clean similarity map (numpy array or array-like).
        sim_a: 2-D adversarial similarity map with the same shape as ``sim_c``.
        H, W: Unused. The overlay size always comes from ``cv2_img``; the
            parameters are kept so existing calls that pass them still work.
    """
    img_h, img_w = cv2_img.shape[:2]

    def blend(sim):
        # Min-max normalise; the epsilon avoids 0/0 on a constant map.
        sim = sim - sim.min()
        sim = sim / (sim.max() + 1e-8)

        # Resize heatmap up to image size (cv2 dsize is (width, height)).
        sim_resized = cv2.resize(sim, (img_w, img_h))

        # Convert to color heatmap
        hm = (sim_resized * 255).astype(np.uint8)
        hm = cv2.applyColorMap(hm, cv2.COLORMAP_JET)

        # Blend with original cv2 image
        return cv2.cvtColor(
            (0.4 * cv2_img + 0.6 * hm).astype("uint8"),
            cv2.COLOR_BGR2RGB
        )

    # Upcast once up front. cv2.resize does not accept float16 input (which a
    # half-precision model would produce), and the 1e-8 epsilon above rounds
    # to 0 in float16, which would turn a constant map into NaNs.
    sim_c = np.asarray(sim_c, dtype=np.float32)
    sim_a = np.asarray(sim_a, dtype=np.float32)

    clean_map = blend(sim_c)
    adv_map = blend(sim_a)
    # blend() min-max normalises the signed difference, so the most negative
    # change maps to the low (blue) end of JET and the most positive to red.
    diff_map = blend(sim_a - sim_c)

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(clean_map); ax[0].set_title(f"{name} — Clean"); ax[0].axis('off')
    ax[1].imshow(adv_map);   ax[1].set_title(f"{name} — Adv");   ax[1].axis('off')
    ax[2].imshow(diff_map);  ax[2].set_title(f"{name} — Diff");  ax[2].axis('off')

    plt.tight_layout()
    plt.show()


def _cls_metrics(z_clean, z_adv, text_features):
    """Return ``(cos_cls, max_delta)`` comparing clean and adversarial embeddings.

    ``cos_cls`` is the cosine similarity of ``z[:, 0, :]`` (presumably the CLS
    token) between the two runs. ``max_delta`` is the largest absolute change
    in the softmax over ``z[:, 0, :] @ text_features.T``. The ``.item()``
    calls require a batch of one.
    """
    cls_clean = z_clean[:, 0, :]
    cls_adv = z_adv[:, 0, :]

    cos_cls = F.cosine_similarity(cls_clean, cls_adv).item()

    # NOTE(review): no logit scale / temperature is applied. If both z and
    # text_features are unit-norm (this script normalises text_features before
    # the call), the logits lie in [-1, 1]. The softmax is then nearly
    # uniform, which keeps MaxΔProb small — confirm this is intended.
    p_clean = torch.softmax(cls_clean @ text_features.T, 1)
    p_adv = torch.softmax(cls_adv @ text_features.T, 1)
    max_delta = (p_adv - p_clean).abs().max().item()
    return cos_cls, max_delta


def run_full_ablation_pipeline(
    image, model, text_features, target_idx, cv2_img,
    ablations
):
    """Run one adversarial attack per ablation config and summarise the results.

    For each ``(name, params)`` in ``ablations`` (in insertion order), the
    function:
      1. calls ``generate_adversarial_image`` on a clone of ``image``,
         forwarding ``params`` as keyword arguments;
      2. builds clean and adversarial similarity maps with ``compute_sim_maps``;
      3. computes metrics: CLS cosine similarity, the max absolute change in
         softmax probability, top-k IoU for ``target_idx``, and the final loss
         and target similarity from the attack history;
      4. plots the clean / adversarial / difference heatmaps.

    Args:
        image: Preprocessed clean image batch.
        model: Model attacked and used to build similarity maps.
        text_features: Text embeddings, one row per vocabulary entry.
        target_idx: Int index of the attack target text (also the heatmap and
            IoU channel).
        cv2_img: BGR background image for the heatmap overlays.
        ablations: Mapping ``name -> dict`` of extra keyword arguments for
            ``generate_adversarial_image``. The history it returns must
            provide ``"loss"`` and ``"sim_target_mean"`` sequences.

    Returns:
        pandas.DataFrame: One row per ablation with columns ``Ablation``,
        ``FinalLoss``, ``CosSim_CLS``, ``MaxΔProb``, ``IoU_Topk`` and
        ``TargetSim``. The DataFrame is also printed.
    """
    results = []

    for name, params in ablations.items():

        print("\n===================================================")
        print(f"     Running Attack Ablation: {name}")
        print("===================================================")

        # 1) Run the attack on a clone so the caller's tensor is untouched.
        image_adv, history = generate_adversarial_image(
            image=image.clone(),
            model=model,
            text_features=text_features,
            target_idx=target_idx,
            **params
        )

        # 2) Similarity maps for the clean and adversarial images.
        z_clean, sim_clean = compute_sim_maps(model, image, text_features)
        z_adv, sim_adv = compute_sim_maps(model, image_adv, text_features)

        # Upcast to float32 before leaving torch, so a half-precision model
        # does not hand float16 arrays to the OpenCV plotting code.
        sim_clean_t = sim_clean[0, :, :, target_idx].detach().float().cpu().numpy()
        sim_adv_t = sim_adv[0, :, :, target_idx].detach().float().cpu().numpy()

        # 3) Metrics
        cos_cls, max_delta = _cls_metrics(z_clean, z_adv, text_features)
        iou = compute_IoU(sim_clean, sim_adv, target_idx)

        results.append({
            "Ablation": name,
            # float() (as for TargetSim) so the DataFrame holds plain numbers
            # even if the history stores 0-d tensors.
            "FinalLoss": float(history["loss"][-1]),
            "CosSim_CLS": cos_cls,
            "MaxΔProb": max_delta,
            "IoU_Topk": iou,
            "TargetSim": float(history["sim_target_mean"][-1])
        })

        # 4) Plot heatmaps
        plot_ablation_heatmaps(
            name,
            cv2_img,
            sim_clean_t,
            sim_adv_t
        )

    df = pd.DataFrame(results)
    print("\n====================== FINAL ABLATION SUMMARY ======================")
    print(df)
    return df

# --------------------------
# Define ablations
# --------------------------
# Each entry maps an ablation name to the keyword arguments that
# run_full_ablation_pipeline forwards to generate_adversarial_image.
# Insertion order fixes the run order (and the row order of the summary).
ablations = dict(
    full_loss=dict(
        steps=200,
        lambda_pred=0.5,
        lambda_entropy=0.5,
        lambda_margin=0.5,
        L_xai_scale=5.0,
    ),
    xai_only=dict(
        steps=200,
        lambda_pred=0.0,
        lambda_entropy=0.5,
        lambda_margin=0.0,
        L_xai_scale=5.0,
    ),
    pred_only=dict(
        steps=200,
        lambda_pred=0.5,
        lambda_entropy=0.0,
        lambda_margin=0.5,
        L_xai_scale=0.0,
    ),
)

# Attack toward the adversarial target text and collect one summary row per ablation.
df_ablation = run_full_ablation_pipeline(
    ablations=ablations,
    cv2_img=cv2_img,
    target_idx=target_idx_adv,
    text_features=text_features,
    model=model,
    image=image,
)

# Grouped bar chart of the summary metrics, one group per ablation.
df_ablation.plot(
    kind="bar",
    x="Ablation",
    y=["IoU_Topk", "CosSim_CLS", "MaxΔProb", "TargetSim"],
    title="Loss Component Ablation — Heatmap Attack Behavior",
    figsize=(10, 6),
)
plt.tight_layout()
plt.show()
