# =============================================================
# adaptive_attack.py — Adaptive XAI-aware adversarial generator
# Author: ChatGPT (Clean version)
# =============================================================
import torch
import torch.nn.functional as F
import clip


# -------------------------------------------------------------
# Utility: L∞ projection in CLIP normalized space
# -------------------------------------------------------------
def _linf_project(delta, eps_norm):
    return torch.max(torch.min(delta, eps_norm), -eps_norm)


# -------------------------------------------------------------
# Utility: token extraction (CLS + patches)
# -------------------------------------------------------------
def _encode_tokens(model, x):
    z = model.encode_image(x)
    if z.dim() != 3:
        raise RuntimeError("Model must return token-level features [1, T, D].")
    z = z / (z.norm(dim=-1, keepdim=True) + 1e-6)
    cls = z[:, 0, :]
    patches = z[:, 1:, :]
    return cls, patches


# -------------------------------------------------------------
# Feature Surgery (Stage II)
# -------------------------------------------------------------
def _feature_surgery_tokens(z, text_features, temp=2.0):
    """Thin wrapper around ``clip.clip_feature_surgery``.

    Args:
        z: Image token features; the original note gives the shape as
            [1, T, D] (not checked here).
        text_features: Text embeddings; the original note gives the shape as
            [N, D] (not checked here).
        temp: Forwarded as the ``t`` keyword of ``clip_feature_surgery``.

    Returns:
        Whatever ``clip.clip_feature_surgery`` returns. The original author
        documents it as a [1, T, N] similarity map -- TODO confirm against the
        installed ``clip`` package.
    """
    # redundant_feats is passed explicitly as None, exactly as before.
    surgery_kwargs = {"redundant_feats": None, "t": temp}
    return clip.clip_feature_surgery(z, text_features, **surgery_kwargs)


# -------------------------------------------------------------
# Top-K patch selection (adaptive)
# -------------------------------------------------------------
def _select_topk(sim_p, K):
    """sim_p: [num_patches], returns indices"""
    K_eff = min(K, sim_p.shape[0])
    return sim_p.topk(K_eff, dim=0).indices


def differentiable_feature_surgery(image_tokens, text_features, t=2.0):
    """Class-reweighted token/text similarity built only from differentiable ops.

    Both inputs are L2-normalised along their last dimension. The CLS token
    (index 0 along dim 1) is scored against every text embedding, the scores
    are scaled by ``t`` and soft-maxed, and the resulting weights (rescaled to
    mean 1 per sample) multiply the text embeddings before the token/text
    similarity is taken.

    Args:
        image_tokens: [B, T, D] token features, CLS first. Any batch size is
            accepted; each sample gets its own class weights.
        text_features: [N, D] text embeddings.
        t: Multiplier applied to the CLS/text scores before the softmax.

    Returns:
        Tensor of shape [B, T, N] with the reweighted similarities.
    """
    tokens = image_tokens / (image_tokens.norm(dim=-1, keepdim=True) + 1e-6)
    text = text_features / (text_features.norm(dim=-1, keepdim=True) + 1e-6)

    cls_tok = tokens[:, 0, :]                            # [B, D]
    prob = ((cls_tok @ text.T) * t).softmax(dim=-1)      # [B, N]
    weights = prob / prob.mean(dim=-1, keepdim=True)     # [B, N], mean 1 per row

    # One reweighted copy of the text features per sample. Scaling the text
    # features before the matmul keeps the same operation order as the
    # single-sample formulation, so B == 1 results are unchanged.
    text_w = text.unsqueeze(0) * weights.unsqueeze(-1)   # [B, N, D]
    return tokens @ text_w.transpose(1, 2)               # [B, T, N]


# -------------------------------------------------------------
# Token encoder
# -------------------------------------------------------------
def _encode_tokens(model, x):
    z = model.encode_image(x)
    if z.dim() != 3:
        raise RuntimeError("Model must return token-level output [1,T,D]")
    z = z / (z.norm(dim=-1, keepdim=True) + 1e-6)
    CLS = z[:, 0, :]
    patches = z[:, 1:, :]
    return CLS, patches, z


# -------------------------------------------------------------
# L∞ projection
# -------------------------------------------------------------
def _linf(delta, eps):
    return torch.max(torch.min(delta, eps), -eps)

import torch
import torch.nn.functional as F


# -------------------------------------------------------
# Differentiable Stage-II Feature Surgery (FS)
# -------------------------------------------------------
def differentiable_FS(tokens, text_features, t=2.0):
    """Differentiable class-reweighted similarity between tokens and text.

    Inputs are L2-normalised along the last dimension. The CLS token (index 0
    along dim 1) is compared with each text embedding; the scores, scaled by
    ``t`` and soft-maxed, become per-class weights that are rescaled to mean 1
    for every sample and applied to the text embeddings before the token/text
    similarity is computed.

    Args:
        tokens: [B, T, D] token features, CLS first. Any batch size works;
            weights are computed independently for each sample.
        text_features: [N, D] text embeddings.
        t: Multiplier applied to the CLS/text scores before the softmax.

    Returns:
        Tensor of shape [B, T, N].
    """
    feats = tokens / (tokens.norm(dim=-1, keepdim=True) + 1e-6)
    text = text_features / (text_features.norm(dim=-1, keepdim=True) + 1e-6)

    cls_tok = feats[:, 0, :]                             # [B, D]
    presence = ((cls_tok @ text.T) * t).softmax(dim=-1)  # [B, N]
    # Normalise per sample: a global mean would couple the samples of a
    # batch. For B == 1 the two are the same value.
    presence = presence / presence.mean(dim=-1, keepdim=True)

    text_w = text.unsqueeze(0) * presence.unsqueeze(-1)  # [B, N, D]
    return feats @ text_w.transpose(1, 2)                # [B, T, N]


# -------------------------------------------------------
# Encode tokens from NORMAL CLIP (differentiable)
# -------------------------------------------------------
def encode_tokens_base(model, x):
    """Encode ``x`` with ``model`` and return unit-norm CLS, patch and full tokens.

    Gradients flow through this call (no ``no_grad`` is applied here).

    Args:
        model: Object exposing ``encode_image``. It must return token-level
            features of rank 3 with the CLS token first along dim 1.
        x: Input forwarded unchanged to ``model.encode_image``.

    Returns:
        ``(CLS, patches, z)`` where ``z`` is the encoder output normalised
        along the last dimension, ``CLS = z[:, 0, :]`` and
        ``patches = z[:, 1:, :]``.

    Raises:
        RuntimeError: If the encoder output is 2-D (a pooled embedding, which
            the original author associates with RN50) or has any other rank
            than 3.
    """
    z = model.encode_image(x)
    if z.dim() == 2:
        raise RuntimeError("RN50 does NOT support token attack. Use ViT-B/16 / ViT-L/14.")
    if z.dim() != 3:
        # Any other rank would be sliced as if dim 1 were the token axis and
        # silently yield meaningless CLS/patch tensors (or an opaque
        # IndexError), so fail loudly instead.
        raise RuntimeError(
            f"Expected token-level output [B,T,D], got shape {tuple(z.shape)}."
        )
    z = z / (z.norm(dim=-1, keepdim=True) + 1e-6)
    CLS = z[:, 0, :]
    patches = z[:, 1:, :]
    return CLS, patches, z


# -------------------------------------------------------
# Adaptive Attacker (FaithShield-aware)
# -------------------------------------------------------
def generate_adaptive_attack(
    model,
    model_surgery,
    image,
    text_features,
    target_idx,
    steps=300,
    adam_lr=5e-3,
    K_patches=40,
    l0_ratio=0.01,
    eps_image=4/255,
    surgery_temp=2.0,
    lambda_xai=10.0,
    lambda_pred=1.0,
    lambda_entropy=0.3,
    lambda_margin=0.3,
    verbose=True
):
    """Optimise a sparse, L-inf bounded perturbation with Adam.

    Each step minimises a weighted sum of four terms, then projects the
    perturbation back onto the L0 / L-inf constraint set:

    * ``L_xai``     -- negative mean of the target-class similarity map over
      the top-``K_patches`` patches. The patches are chosen from the
      ``model_surgery`` map (evaluated without gradients); the value being
      optimised comes from ``model`` so that it carries gradient.
    * ``L_entropy`` -- ``sum(m * log m)`` of the soft-maxed, min-max scaled
      map, i.e. the negative entropy; minimising it spreads the map out.
    * ``L_pred``    -- cross-entropy of the CLS logits against the class that
      ``model`` predicts for the unperturbed image.
    * ``L_margin``  -- hinge ``relu(max_other - target + 0.1)`` averaged over
      patches, computed on raw patch/text similarities.

    Args:
        model: Encoder used for the gradient path; ``encode_image`` must
            return token-level features [1, T, D] with CLS first.
        model_surgery: Encoder whose ``encode_image`` output is only observed
            under ``no_grad`` to choose the top-K patches.
        image: Input image batch of size 1. NOTE(review): the per-channel
            budget below assumes a 3-channel image already normalised with the
            CLIP std constants -- confirm against the caller.
        text_features: [N, D] text embeddings.
        target_idx: Index of the class whose similarity map is attacked.
        steps: Number of optimisation steps; 0 returns the image unchanged.
        adam_lr: Adam learning rate.
        K_patches: Number of top patches used by ``L_xai`` (capped at the
            number of patches).
        l0_ratio: Fraction of perturbation entries allowed to be non-zero.
        eps_image: L-inf budget in un-normalised pixel units.
        surgery_temp: Temperature passed to ``differentiable_FS``.
        lambda_xai, lambda_pred, lambda_entropy, lambda_margin: Loss weights.
        verbose: Print the loss every 20 steps.

    Returns:
        ``(adv, info)`` where ``adv = image + delta`` (detached) and ``info``
        is ``{"loss": float}`` holding the loss evaluated at the start of the
        last step, or ``nan`` when ``steps`` is 0.

    NOTE(review): the result is not clamped to the valid pixel range.
    """
    device = image.device

    # ---------------- Per-channel budget in normalised space --------
    # A pixel-space budget eps_image becomes eps_image / std once the image
    # is normalised. Built in the image dtype so the projection below never
    # promotes delta to a wider type.
    STD = torch.tensor(
        [0.26862954, 0.26130258, 0.27577711], device=device, dtype=image.dtype
    ).view(1, 3, 1, 1)
    EPS = eps_image / STD

    text_t = text_features.T  # loop-invariant, [D, N]

    # ---------------- Clean CLS (preserve prediction) --------------
    with torch.no_grad():
        z_clean = model.encode_image(image)
        z_clean = z_clean / (z_clean.norm(dim=-1, keepdim=True) + 1e-6)
        logits_clean = z_clean[:, 0, :] @ text_t
        y_star = int(logits_clean.argmax().item())

    # Boolean mask over classes used to exclude the target from "other" max.
    num_classes = text_features.shape[0]
    target_mask = torch.zeros(num_classes, dtype=torch.bool, device=text_features.device)
    target_mask[target_idx] = True

    # ---------------- delta initialization -------------------------
    delta = torch.zeros_like(image, requires_grad=True)
    opt = torch.optim.Adam([delta], lr=adam_lr)

    total_pixels = delta.numel()
    # Cap at total_pixels so l0_ratio > 1 cannot ask topk for too many entries.
    k_pixels = min(total_pixels, max(1, int(total_pixels * l0_ratio)))

    final_loss = None

    for step in range(steps):

        x = image + delta

        # ---------------- Differentiable gradient path ----------------
        CLS, patches, tokens = encode_tokens_base(model, x)
        sim_p = differentiable_FS(tokens, text_features, t=surgery_temp)[0, 1:, target_idx]

        # ---------------- Observer signal (no gradient) ---------------
        with torch.no_grad():
            tokens_s = model_surgery.encode_image(x)
            tokens_s = tokens_s / (tokens_s.norm(dim=-1, keepdim=True) + 1e-6)
            sim_obs = differentiable_FS(tokens_s, text_features, t=surgery_temp)[0, 1:, target_idx]

        if sim_obs.shape[0] != sim_p.shape[0]:
            # Observer indices would not line up with the gradient-path
            # patches; rank by the gradient-path map instead.
            sim_obs = sim_p.detach()

        # ---------------- Loss terms ---------------------------------
        # Patches are picked by the observer, but the optimised values come
        # from the differentiable map: a loss built from no_grad tensors
        # would be a constant and contribute no gradient.
        K = min(K_patches, sim_p.shape[0])
        idx = sim_obs.topk(K).indices
        L_xai_term = -sim_p[idx].mean()

        sim_norm = (sim_p - sim_p.min()) / (sim_p.max() - sim_p.min() + 1e-6)
        m = torch.softmax(sim_norm, dim=0)
        L_entropy = (m * m.log()).sum()

        logits = CLS @ text_t
        L_pred = -torch.log_softmax(logits, dim=-1)[0, y_star]

        # patches is [1, P, D]; drop the batch axis so rows are patches and
        # columns are classes before indexing by class.
        logits_p = patches[0] @ text_t                  # [P, N]
        tgt_scores = logits_p[:, target_idx]            # [P]
        if num_classes > 1:
            other_max = logits_p.masked_fill(target_mask, float("-inf")).max(dim=1).values
            L_margin = torch.relu(other_max - tgt_scores + 0.1).mean()
        else:
            # No competing class exists, so the margin term is vacuous.
            L_margin = logits_p.new_zeros(())

        loss = (
            lambda_xai * L_xai_term +
            lambda_entropy * L_entropy +
            lambda_pred * L_pred +
            lambda_margin * L_margin
        )

        # ---------------- Backward ----------------
        opt.zero_grad()
        loss.backward()
        opt.step()

        # ---------------- Projection ----------------
        with torch.no_grad():
            # reshape (not view): zeros_like keeps the strides of `image`,
            # which need not be contiguous.
            flat = delta.reshape(-1)
            keep = flat.abs().topk(k_pixels).indices
            mask = torch.zeros_like(flat)
            mask[keep] = 1.0
            projected = (flat * mask).reshape(delta.shape)      # L0: keep top-k
            projected = torch.max(torch.min(projected, EPS), -EPS)  # L-inf box
            delta.copy_(projected)

        final_loss = loss.detach()

        if verbose and step % 20 == 0:
            print(f"[Adaptive {step}/{steps}]  loss={loss.item():.4f}")

    adv = image + delta.detach()
    loss_value = float("nan") if final_loss is None else float(final_loss.item())
    return adv, {"loss": loss_value}
