#!/usr/bin/env python
# coding: utf-8
"""
VT‑Attack (CLIP)
2025‑04‑18
───────────────────────────────────────────────────────────────────────────────
Key speed‑ups
  • Automatic Mixed Precision (bf16) on CUDA
  • Optional `torch.compile` for encoders / projections
  • Fewer tensor copies & detach / requires_grad toggles
  • Constants and buffers created once, outside the PGD loop
  • TF32 enabled on Ampere+ GPUs
The attack logic and output are bit‑for‑bit identical in full FP32 mode.
───────────────────────────────────────────────────────────────────────────────
"""

import os, argparse, cv2, numpy as np, torch, torch.nn.functional as F
from glob import glob
from kmeans_pytorch import kmeans
from sklearn.metrics import silhouette_score
from transformers import CLIPModel, CLIPProcessor
from tqdm import tqdm
from natsort import natsorted
import json
from PIL import Image

# ───────────────────────────── Speed related switches ────────────────────────
# Let float32 matmuls / cuDNN kernels use TF32 (see the "TF32 enabled on
# Ampere+ GPUs" note in the module docstring).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

# Image pre-processing settings.  `preprocess_np` reads only the
# convert / resize / crop entries; the normalization entries
# (`do_normalize`, `image_mean`, `image_std`) are not consulted by it.
config = dict(
    crop_size=dict(height=336, width=336),
    do_center_crop=True,
    do_convert_rgb=True,
    do_normalize=True,
    do_resize=True,
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
    resample=3,                      # integer code looked up in RESAMPLE_MAP (3 → bicubic)
    size=dict(shortest_edge=336),
)

# Integer resample code (as stored in `config["resample"]`) → PIL filter.
# The tuple position is the code: 0 nearest, 1 Lanczos, 2 bilinear, 3 bicubic.
RESAMPLE_MAP = dict(enumerate((
    Image.NEAREST,
    Image.LANCZOS,
    Image.BILINEAR,
    Image.BICUBIC,
)))

def transforms_resize_shortest(img: Image.Image, size: int, resample) -> Image.Image:
    """Resize ``img`` so that its shorter side equals ``size``.

    The aspect ratio is kept; the longer side is scaled by the same factor
    and truncated to an integer.  If the shorter side already equals
    ``size`` the input object itself is returned, without copying.

    Args:
        img: object exposing ``.size`` as ``(width, height)`` and ``.resize``.
        size: target length of the shorter side, in pixels.
        resample: resampling filter forwarded to ``img.resize``.
    """
    width, height = img.size

    # Nothing to do when the shorter edge is already at the target length.
    if min(width, height) == size:
        return img

    if width < height:
        target = (size, int(height * size / width))
    else:
        target = (int(width * size / height), size)
    return img.resize(target, resample=resample)


def center_crop(img: Image.Image, crop_h: int, crop_w: int) -> Image.Image:
    """Crop a ``crop_w`` x ``crop_h`` window around the centre of ``img``.

    The window's top-left corner is never negative: when the image is
    smaller than the window along an axis, the window starts at 0 on that
    axis and extends past the image edge (the box is handed to
    ``img.crop`` unchanged).
    """
    width, height = img.size
    x0 = max((width - crop_w) // 2, 0)
    y0 = max((height - crop_h) // 2, 0)
    box = (x0, y0, x0 + crop_w, y0 + crop_h)
    return img.crop(box)

def preprocess_np(image: np.ndarray, config: dict) -> np.ndarray:
    """Run the geometric pre-processing steps enabled in ``config`` on an image array.

    Steps, each gated by its ``do_*`` flag (a missing flag counts as off):
      1. ``do_convert_rgb``  – convert to RGB mode
      2. ``do_resize``       – resize the shorter side to ``size.shortest_edge``
                               using the filter selected by ``resample``
                               (unknown codes fall back to bicubic)
      3. ``do_center_crop``  – crop ``crop_size.height`` x ``crop_size.width``

    No normalization or dtype conversion is applied here; the result is
    whatever ``np.array`` yields for the final PIL image.
    """
    pil_img = Image.fromarray(image)

    if config.get("do_convert_rgb", False):
        pil_img = pil_img.convert("RGB")

    if config.get("do_resize", False):
        target_edge = config["size"]["shortest_edge"]
        pil_filter = RESAMPLE_MAP.get(config.get("resample", 3), Image.BICUBIC)
        pil_img = transforms_resize_shortest(pil_img, target_edge, pil_filter)

    if config.get("do_center_crop", False):
        crop = config["crop_size"]
        pil_img = center_crop(pil_img, crop["height"], crop["width"])

    # Back to an (H, W, C) array.
    return np.array(pil_img)


def extract_image_list(jsonl_path):
    """Collect the ``"image"`` field of every record in a JSON-lines file.

    Args:
        jsonl_path: path to a text file holding one JSON object per line.

    Returns:
        List of the ``"image"`` values in file order (duplicates are kept).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``"image"`` key.
    """
    image_list = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at end of file) carry no
            # record and would otherwise make json.loads raise.
            if not line:
                continue
            image_list.append(json.loads(line)["image"])
    return image_list

# ────────────────────────────── Helper Modules ───────────────────────────────
class Denormalize(torch.nn.Module):
    """Undo per-channel normalization: ``x * std + mean``.

    ``mean`` and ``std`` are stored as buffers of shape (C, 1, 1) so they
    broadcast over the spatial dimensions of a (B, C, H, W) input.
    """

    def __init__(self, mean, std, dtype, device):
        super().__init__()
        for buf_name, values in (("mean", mean), ("std", std)):
            stat = torch.tensor(values, dtype=dtype, device=device)
            self.register_buffer(buf_name, stat.view(-1, 1, 1))

    def forward(self, x):
        # Inverse of (x - mean) / std.
        rescaled = x * self.std
        return rescaled + self.mean


class Normalize(torch.nn.Module):
    """Per-channel normalization: ``(x - mean) / std``.

    ``mean`` and ``std`` are stored as buffers of shape (C, 1, 1) so they
    broadcast over the spatial dimensions of a (B, C, H, W) input.
    """

    def __init__(self, mean, std, dtype, device):
        super().__init__()
        mean_t = torch.tensor(mean, dtype=dtype, device=device)
        std_t = torch.tensor(std, dtype=dtype, device=device)
        self.register_buffer("mean", mean_t.view(-1, 1, 1))
        self.register_buffer("std", std_t.view(-1, 1, 1))

    def forward(self, x):
        centered = x - self.mean
        return centered / self.std


# ───────────────────────────── PGD‑style attack ──────────────────────────────
def vt_attack(x_norm, img_enc,
              normalize, epsilon=8/255, alpha=1/255, steps=10,
              λ_feat=1.0, device='cpu', bf16=True):
    """
    Args
      x_norm  : (1,C,H,W) CLIP‑normalized image (float)
      epsilon : L‑∞ radius   (pixel space, 0‑1 scale)
      alpha   : PGD step size (pixel space, 0‑1 scale)
    Return
      adversarial image in **normalized** space (same dtype as x_norm)
    """
    dtype   = x_norm.dtype
    x_norm  = x_norm.to(device).detach()
    initial_noise = torch.empty_like(x_norm).uniform_(-epsilon * 255, epsilon * 255)
    x_adv   = (x_norm.clone() + initial_noise).requires_grad_(True)         # (requires_grad set ONCE)

    # ── constant pre‑computations ────────────────────────────────────────────
    with torch.no_grad():
        min_norm = normalize(torch.zeros_like(x_norm))     # bounds in norm‑space
        max_norm = normalize(torch.ones_like(x_norm))

    std         = normalize.std.view(1, -1, 1, 1)
    eps_norm    = (epsilon / std).to(dtype)
    alpha_norm  = (alpha   / std).to(dtype)

    # clean visual features (once)
    with torch.no_grad():
        v_clean      = img_enc(x_norm).last_hidden_state
        cls_clean, tok_clean = v_clean[:, 0], v_clean[:, 1:]
            
    autocast_on = (device == 'cuda' and bf16)
    for _ in range(steps):
        with torch.amp.autocast('cuda', enabled=autocast_on):
            v_adv       = img_enc(x_adv).last_hidden_state

            # (1) token feature loss
            l_feat = F.mse_loss(v_adv[:, 1:], tok_clean)

            loss   = l_feat * λ_feat

        x_adv.grad = None
        loss.backward()

        # PGD update (sign(∇))
        grad_sign   = x_adv.grad.sign()
        x_adv.data.add_(alpha_norm * grad_sign)

        # ε‑ball projection (+ clip to [0,1] in norm‑space)
        delta       = (x_adv - x_norm).clamp(-eps_norm, eps_norm)
        x_adv.data  = (x_norm + delta).clamp_(min=min_norm, max=max_norm)

    return x_adv.detach()                              # still normalized


# ─────────────────────────────────── Main ────────────────────────────────────
# Question files of the POPE splits (paths are relative to the working directory).
_POPE_QUESTION_FILES = {
    "pope_rand": "dataset/coco/pope/coco_pope_random.json",
    "pope_pop":  "dataset/coco/pope/coco_pope_popular.json",
    "pope_adv":  "dataset/coco/pope/coco_pope_adversarial.json",
}


def _collect_image_paths(args):
    """Return the naturally‑sorted list of image paths selected by ``args.dataset``."""
    if args.dataset == "chair":
        # --data is either a single .jpg or a directory of .jpg files.
        if args.data.lower().endswith(".jpg"):
            return [args.data]
        return natsorted(glob(os.path.join(args.data, "*.jpg")))

    if args.dataset in _POPE_QUESTION_FILES:
        # Each image appears in several questions; attack it only once.
        names = set(extract_image_list(_POPE_QUESTION_FILES[args.dataset]))
        return natsorted([os.path.join(args.data, name) for name in names])

    if args.dataset == "amber":
        # Fixed location; --data is not used for this dataset.
        return natsorted(glob("dataset/AMBER/image/*.jpg", recursive=True))

    raise ValueError(
        f"unknown or missing --dataset: {args.dataset!r} "
        f"(expected one of: chair, {', '.join(_POPE_QUESTION_FILES)}, amber)"
    )


def run(args):
    """Attack every image of the selected dataset and save the results as PNG.

    Reads from ``args``: ``model``, ``data``, ``output``, ``dataset``,
    ``epsilon`` and ``alpha`` (both on the 0‑255 pixel scale), ``steps``,
    ``lambda_feat`` and ``bf16``.

    For each input ``<name>.<ext>`` the adversarial image is written to
    ``<args.output>/<name>.png``; inputs whose output file already exists are
    skipped, as are inputs that cannot be decoded.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
        AssertionError: if the saved perturbation exceeds ε (sanity check).
        OSError: if an output image cannot be written.
    """
    dev   = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if (dev == "cuda" and args.bf16) else torch.float32
    os.makedirs(args.output, exist_ok=True)
    print(f"[device] {dev} ({dtype})")

    # model & processor
    model = CLIPModel.from_pretrained(args.model, torch_dtype=dtype).to(dev).eval()
    # Only gradients w.r.t. the image are needed; freezing the weights avoids
    # computing and storing a gradient for every model parameter at each step.
    model.requires_grad_(False)
    img_enc   = model.vision_model.eval()
    proc      = CLIPProcessor.from_pretrained(args.model)
    mean, std = proc.image_processor.image_mean, proc.image_processor.image_std
    print(f"[norm] mean={mean}, std={std}")

    normalize   = Normalize(mean, std, dtype, dev)
    denormalize = Denormalize(mean, std, dtype, dev)

    # image list
    imgs = _collect_image_paths(args)

    for pth in tqdm(imgs):
        # Results are always PNG (lossless).  The output name is computed once
        # so that the skip check looks at the file that is actually written.
        stem     = os.path.splitext(os.path.basename(pth))[0]
        out_path = os.path.join(args.output, stem + ".png")
        if os.path.exists(out_path):
            print(f"✓ {out_path} exists — skip")
            continue

        # ────── load & preprocess ──────
        bgr = cv2.imread(pth)
        if bgr is None:                       # cv2.imread signals failure with None
            print(f"✗ {pth} could not be read — skip")
            continue
        rgb   = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        rgb   = preprocess_np(rgb, config)
        img_t = proc(images=rgb, return_tensors="pt").pixel_values.to(dev, dtype)   # already normalized

        # ────── attack ──────
        adv_norm = vt_attack(img_t,
                             img_enc=img_enc,
                             normalize=normalize,
                             epsilon=args.epsilon/255, alpha=args.alpha/255,
                             steps=args.steps, device=dev, bf16=args.bf16,
                             λ_feat=args.lambda_feat)

        # ────── sanity check ──────
        # NOTE: the uint8 cast truncates the measured L‑∞ to whole pixel levels,
        # so sub‑level overshoot (e.g. from low‑precision rounding) is tolerated.
        linf = ((denormalize(adv_norm) - denormalize(img_t)).abs().max() * 255).to(torch.uint8).item()
        assert linf <= args.epsilon + 1e-3, f"L‑∞ {linf:.2f} > ε={args.epsilon}"

        # ────── save (lossless) ──────
        adv_img = (denormalize(adv_norm).squeeze(0).permute(1, 2, 0).cpu().float().numpy())
        adv_img = np.clip(adv_img * 255 + 0.5, 0, 255).astype(np.uint8)           # round→uint8
        if not cv2.imwrite(out_path, cv2.cvtColor(adv_img, cv2.COLOR_RGB2BGR)):
            raise OSError(f"failed to write {out_path}")
        print(f"★ saved → {out_path}   (max Δ={linf:.1f})")

# ────────────────────────────────── CLI ──────────────────────────────────────
if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--model",   default="openai/clip-vit-large-patch14-336")
    ap.add_argument("--data",    required=True,  help="jpg file or dir/*.jpg")
    ap.add_argument("--output",  required=True)
    ap.add_argument("--epsilon", type=int,   default=3,   help="pixel‑space ε (0‑255)")
    ap.add_argument("--alpha",   type=int,   default=1,   help="pixel‑space α (0‑255)")
    ap.add_argument("--steps",   type=int,   default=200, help="#PGD iterations")
    ap.add_argument("--lambda_feat", type=float, default=1.0)
    # Required: run() has no image list to work on without a dataset name, so
    # fail at argument parsing with a clear message instead of later on.
    ap.add_argument("--dataset", type=str, required=True,
                    choices=["chair", "pope_rand", "pope_pop", "pope_adv", "amber"],
                    help="image set to attack: 'chair' takes --data as a jpg file or a "
                         "directory of jpgs; 'pope_*' join the image names listed in "
                         "dataset/coco/pope/ onto --data; 'amber' reads "
                         "dataset/AMBER/image/*.jpg")

    # speed toggles (bf16 is on by default; --no-bf16 turns it off)
    ap.add_argument("--bf16",    action="store_true",  default=True,  help="use bf16 AMP on CUDA")
    ap.add_argument("--no-bf16", dest="bf16", action="store_false")

    run(ap.parse_args())
