import os
import sys
from typing import Dict, List

import clip
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, wasserstein_distance, wilcoxon, ttest_rel
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode


# Compute device for every module-level tensor: the GPU when torch can see one,
# otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Per-channel (R, G, B) normalization statistics; the values match CLIP's
# published preprocessing constants. Stored as [1, 3, 1, 1] so they broadcast
# against [N, 3, H, W] image batches.
MEAN = torch.tensor(
    [0.48145466, 0.4578275, 0.40821073],
    device=device,
).reshape(1, 3, 1, 1)
STD = torch.tensor(
    [0.26862954, 0.26130258, 0.27577711],
    device=device,
).reshape(1, 3, 1, 1)

# Where raw pixel values 0.0 and 1.0 land after normalization; used as the
# valid per-channel range for normalized images.
NORM_LO, NORM_HI = (0.0 - MEAN) / STD, (1.0 - MEAN) / STD



def project_L0(delta, k):
    """Keep the ``k`` largest-magnitude entries of ``delta`` and zero the rest.

    Args:
        delta: Tensor of any shape. It may be non-contiguous (e.g. a transpose
            or a channels_last tensor).
        k: Number of entries to keep, counted over the flattened tensor.

    Returns:
        ``delta`` itself when ``k`` is at least the number of elements;
        otherwise a new tensor with the shape of ``delta`` in which all but
        the ``k`` largest-``abs`` entries are zero.
    """
    # reshape (not view) so non-contiguous inputs are flattened instead of raising.
    flat = delta.reshape(-1)
    if k >= flat.numel():
        return delta

    _, idx = torch.topk(flat.abs(), k, sorted=False)
    mask = torch.zeros_like(flat)
    mask[idx] = 1.0
    # Multiplying by the mask (rather than indexing) keeps the op differentiable
    # with respect to the surviving entries.
    return (flat * mask).view_as(delta)

def generate_adversarial_image(
    image, model, text_features, target_idx,
    preserve_idx=None,
    steps=100, step_size=1e-3,
    lambda_pred=0.2, lambda_entropy=0.5, lambda_margin=0.5,
    K_patches=20, l0_ratio=0.01, use_margin=False, margin=0.25,
    use_surgery=False, model_surgery=None, surgery_temp=2.0,
    log_every=5,
    # robust/backtracking hyperparams (tweakable)
    lambda_kld=0.1,         # KL to clean logits
    cos_floor=0.98,         # relax slightly to allow progress
    backtrack_fac=0.7,      # gentler shrink
    max_backtracks=2,
    # ADAM option (recommended)
    use_adam_on_delta=True,
    adam_lr=5e-3,
    L_xai_scale=5.0         # boost XAI term to make steering matter
    ):
    """Perturb ``image`` so its patch tokens align with a target text embedding.

    A sparse additive perturbation is optimized so that the top-K patch-token
    similarities to ``text_features[target_idx]`` increase, while auxiliary
    terms keep the global (CLS) prediction on class ``y_star``.

    Args:
        image: Input tensor fed to ``model.encode_image``. The code indexes
            batch element 0 and calls ``.item()`` on the argmax, so a batch of
            exactly one image is required. Values are clamped to
            ``[NORM_LO, NORM_HI]``, i.e. the image is treated as already
            normalized with ``MEAN``/``STD``.
        model: Object exposing ``encode_image``; inside the optimization loop
            it must return a 3-D tensor whose token 0 is used as CLS and whose
            remaining tokens are used as patches.
        text_features: 2-D matrix with one row per class (used as
            ``cls @ text_features.T`` and ``text_features[target_idx]``).
        target_idx: Row of ``text_features`` the patch similarities are pushed
            towards.
        preserve_idx: Class whose prediction must be kept. ``None`` means the
            argmax of the clean logits; otherwise a one-hot target is used for
            the KL term.
        steps: Number of optimization iterations.
        step_size: Initial step of the sign-step fallback.
        lambda_pred, lambda_entropy, lambda_margin, lambda_kld, L_xai_scale:
            Weights of the loss terms (see the loop body for how each is used).
        K_patches: Number of top patch similarities averaged in the XAI term
            (capped at the number of patches).
        l0_ratio: Fraction of tensor elements allowed to be non-zero in the
            perturbation (at least one element is always allowed).
        use_margin, margin: Use a hinge on the logits instead of the negative
            log-likelihood for the prediction term.
        use_surgery, model_surgery, surgery_temp: When enabled, patch
            similarities come from ``clip.clip_feature_surgery`` applied to
            ``model_surgery.encode_image`` features.
        log_every: Print progress every this many steps (and on step 1).
        cos_floor, backtrack_fac, max_backtracks: Acceptance threshold on CLS
            cosine similarity, step shrink factor, and retry budget for the
            sign-step fallback.
        use_adam_on_delta, adam_lr: Optimize the perturbation with Adam
            (default) instead of the sign-step fallback.

    Returns:
        ``(x_adv, history)``: the detached perturbed image and a dict of
        per-step lists with keys ``loss``, ``L_xai``, ``L_pred``, ``cos_cls``
        and ``sim_target_mean``.

    Raises:
        RuntimeError: If ``model.encode_image`` does not return token-level
            (3-D) features inside the loop.
    """
    assert not use_surgery or (model_surgery is not None), "Provide model_surgery when use_surgery=True."
    device = image.device
    # choose epsilon in image space
    EPS_IMAGE = 4.0 / 255.0  # tweak 2/255..8/255

    # The module-level constants live on the import-time default device; move
    # them to the image's device so a CPU image on a CUDA host still works.
    std = STD.to(device)
    norm_lo = NORM_LO.to(device)
    norm_hi = NORM_HI.to(device)

    # Same epsilon expressed in normalized space, per channel; shape [1,3,1,1].
    eps_norm = EPS_IMAGE / std

    def project_Linf(delta):
        return torch.max(torch.min(delta, eps_norm), -eps_norm)

    # ---- clean reference: CLS embedding, preserved class, target distribution
    with torch.no_grad():
        z_clean_tokens = model.encode_image(image)
        cls_clean = z_clean_tokens[:, 0, :] if z_clean_tokens.dim() == 3 else z_clean_tokens
        cls_clean = cls_clean / (cls_clean.norm(dim=-1, keepdim=True) + 1e-6)
        logits0 = cls_clean @ text_features.T

        if preserve_idx is None:
            y_star = int(logits0.argmax(dim=1).item())
            p_clean = torch.softmax(logits0.float(), dim=-1)
        else:
            # Force preservation of a specific class with a one-hot target.
            # float32 so it matches the float log-probabilities in the KL term.
            y_star = int(preserve_idx)
            p_clean = torch.zeros_like(logits0, dtype=torch.float32)
            p_clean[0, y_star] = 1.0

    x0 = image.detach().clone()
    delta_var = torch.zeros_like(x0, device=device, requires_grad=True)
    if use_adam_on_delta:
        adam_opt = torch.optim.Adam([delta_var], lr=adam_lr)

    k_pixels = max(1, int(l0_ratio * x0.numel()))
    t_target = text_features[target_idx].to(device)
    t_target = t_target / (t_target.norm() + 1e-6)

    def _loss_terms(x_in):
        # One forward pass; returns every individual loss term plus the
        # per-patch target similarities used for logging.
        z_tokens = model.encode_image(x_in)
        if z_tokens.dim() != 3:
            raise RuntimeError("Need token-level features for patch-based attack.")
        cls = z_tokens[:, 0, :]
        patches = z_tokens[:, 1:, :].squeeze(0)
        cls = cls / (cls.norm(dim=-1, keepdim=True) + 1e-6)
        patches = patches / (patches.norm(dim=-1, keepdim=True) + 1e-6)
        logits = cls @ text_features.T

        # XAI term: per-patch similarity to the target text.
        if use_surgery:
            # Must be fed x_in (not a detached copy) so the XAI term
            # back-propagates to the perturbation.
            z_tokens_s = model_surgery.encode_image(x_in)
            z_tokens_s = z_tokens_s / (z_tokens_s.norm(dim=-1, keepdim=True) + 1e-6)
            sim_all = clip.clip_feature_surgery(z_tokens_s, text_features, redundant_feats=None, t=surgery_temp)
            sim_p = sim_all[0, 1:, target_idx]
        else:
            sim_p = patches @ t_target

        k_eff = min(K_patches, sim_p.shape[0])
        L_xai = -sim_p.topk(k_eff, dim=0).values.mean()

        # sum(m * log m) is the negative Shannon entropy of the softmax over
        # min-max-normalized patch similarities.
        sim_norm = (sim_p - sim_p.min()) / (sim_p.max() - sim_p.min() + 1e-6)
        m = torch.softmax(sim_norm, dim=0)
        L_entropy = (m * m.clamp_min(1e-8).log()).sum()

        logits_p = patches @ text_features.T
        tgt_scores = logits_p[:, target_idx]
        # NOTE(review): this max includes the target column itself, so the
        # hinge bottoms out at 0.1 rather than 0. Kept as-is to preserve the
        # objective; confirm whether the target should be excluded.
        other_max = logits_p.max(dim=1).values
        L_patch_margin = torch.relu(other_max - tgt_scores + 0.1).mean()

        if use_margin:
            y_logit = logits[0, y_star]
            not_y = torch.arange(logits.shape[1], device=logits.device) != y_star
            L_pred = torch.relu(logits[0, not_y].max() - y_logit + margin)
        else:
            L_pred = -torch.log_softmax(logits.float(), dim=1)[0, y_star]

        L_kld = F.kl_div(torch.log_softmax(logits.float(), dim=-1), p_clean, reduction="batchmean")
        return L_xai, L_entropy, L_patch_margin, L_pred, L_kld, sim_p

    history = {"loss": [], "L_xai": [], "L_pred": [], "cos_cls": [], "sim_target_mean": []}

    # Current adversarial image; defined up front so steps == 0 returns cleanly.
    x = (x0 + delta_var).clamp(norm_lo, norm_hi).detach()

    for step in range(steps):
        if use_adam_on_delta:
            # Differentiable w.r.t. delta_var.
            x_in = (x0 + delta_var).clamp(norm_lo, norm_hi)
        else:
            # Differentiable w.r.t. the image itself.
            x_in = x.clone().requires_grad_(True)

        L_xai, L_entropy, L_patch_margin, L_pred, L_kld, sim_p = _loss_terms(x_in)

        # Fixed-weight objective with stronger emphasis on XAI. This is what
        # history["loss"] records and what the sign-step fallback descends.
        # NOTE(review): the Adam path optimizes loss_for_grad below, whose
        # weights differ from this logged value.
        loss = 20.0*L_xai + 0.5*lambda_entropy*L_entropy + 0.5*lambda_margin*L_patch_margin + 0.01*lambda_pred*L_pred

        if use_adam_on_delta:
            loss_for_grad = (
                L_xai_scale * L_xai
                + lambda_entropy * L_entropy
                + lambda_margin * L_patch_margin
                + lambda_pred * L_pred
                + lambda_kld * L_kld
            )
            adam_opt.zero_grad()
            # autograd.grad instead of backward(): only delta_var receives a
            # gradient, nothing accumulates on the model parameters.
            delta_var.grad = torch.autograd.grad(loss_for_grad, delta_var)[0]
            adam_opt.step()

            with torch.no_grad():
                dv = project_L0(delta_var.detach().clone(), k_pixels)  # keep top-k entries
                dv = project_Linf(dv)                                  # L-inf ball (normalized space)
                dv = (x0 + dv).clamp(norm_lo, norm_hi) - x0            # stay inside valid image range
                delta_var.copy_(dv)
                x = (x0 + delta_var).clamp(norm_lo, norm_hi)
        else:
            # Fallback: signed-gradient descent step with backtracking. A
            # candidate is accepted only if the preserved class is still
            # predicted and the CLS embedding stays close to the clean one.
            grad_sign = torch.autograd.grad(loss, x_in)[0].sign()
            x_prev = x
            step_size_local = step_size
            accepted = False
            for _ in range(max_backtracks):
                with torch.no_grad():
                    delta = project_L0(x_prev - step_size_local * grad_sign - x0, k_pixels)
                    x_candidate = (x0 + delta).clamp(norm_lo, norm_hi)
                    z_now = model.encode_image(x_candidate)
                    cls_now = z_now[:, 0, :] / (z_now[:, 0, :].norm(dim=-1, keepdim=True) + 1e-6)
                    pred_now = int((cls_now @ text_features.T).argmax(dim=1).item())
                    cos_now = F.cosine_similarity(cls_clean, cls_now).item()
                if (pred_now == y_star) and (cos_now >= cos_floor):
                    with torch.no_grad():
                        # Persist progress so the next step starts from here.
                        delta_var.copy_(x_candidate - x0)
                    x = x_candidate
                    accepted = True
                    break
                step_size_local *= backtrack_fac
            if not accepted:
                # Keep the previous image and push harder on preserving y_star.
                lambda_pred = min(10.0, lambda_pred * 1.05)

        # logging (cos_cls is measured on the updated image)
        with torch.no_grad():
            zcls_now = model.encode_image(x)[:, 0, :]
            zcls_now = zcls_now / (zcls_now.norm(dim=-1, keepdim=True) + 1e-6)
            cos_cls = F.cosine_similarity(cls_clean, zcls_now).item()
            sim_target_mean = float(sim_p.mean().item())

        history["loss"].append(float(loss.item()))
        history["L_xai"].append(float(L_xai.item()))
        history["L_pred"].append(float(L_pred.item()))
        history["cos_cls"].append(cos_cls)
        history["sim_target_mean"].append(sim_target_mean)

        if (step + 1) % log_every == 0 or step == 0:
            print(f"[step {step+1}/{steps}] loss={loss.item():.4f}  L_xai={L_xai.item():.4f}  L_pred={L_pred.item():.4f}  cos_cls={cos_cls:.4f}  sim_target_mean={sim_target_mean:.4f}")

    return x.detach(), history
