# -*- coding: utf-8 -*-
from __future__ import annotations
import os
import argparse
import time
from pathlib import Path
from typing import Tuple
from collections import deque

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from diffusers import StableDiffusionPipeline
from lpips import LPIPS
from piq import DISTS

from utils import load_FR_models, get_target_test_images, alignment, read_img

# Decision thresholds per face-recognition (FR) backbone, keyed by model name.
# Each tuple is ordered by FAR bucket (0.1, 0.01, 0.001); build_fr_and_target
# maps --far to tuple index 0/1/2 in that order. The selected value is compared
# against the cosine similarity between an image embedding and the target
# embedding (similarity >= threshold counts as a match in this script).
# NOTE(review): FAR is presumably "false accept rate"; the origin of these
# numbers is not visible in this file.
TH_DICT = {
    'ir152':       (0.094632, 0.166788, 0.227922),
    'irse50':      (0.144840, 0.241045, 0.312703),
    'facenet':     (0.256587, 0.409131, 0.591191),
    'mobile_face': (0.183635, 0.301611, 0.380878),
}

# ----------------------------- I/O utils -----------------------------

def load_and_maybe_crop(path: Path, device, do_crop: bool):
    """Read an image with ``read_img`` and optionally crop it to a face box.

    Per the file's existing convention the tensor is [-1,1] BCHW (produced by
    ``read_img(path, 0.5, 0.5, device)``).

    Args:
        path: Image file to load.
        device: Device forwarded to ``read_img``.
        do_crop: When true, crop the last two dimensions to the box returned
            by ``alignment`` (indexed as ``x1, y1, x2, y2``).

    Returns:
        The cropped tensor when cropping is requested and yields a non-empty
        region; otherwise the uncropped tensor.
    """
    img = read_img(str(path), 0.5, 0.5, device)
    if not do_crop:
        return img

    # Close the file handle deterministically; convert() returns an
    # independent, fully loaded image.
    with Image.open(path) as pil_img:
        bb = alignment(pil_img.convert("RGB"))

    # NOTE(review): run_once passes alignment()'s result to crop_by_bb_m11,
    # which treats None as "no crop"; mirror that here instead of crashing.
    if bb is None:
        return img

    # Clamp to >= 0: a negative coordinate used directly in a slice would wrap
    # around to the far edge and select the wrong (or an empty) region.
    # Coordinates past the far edge need no clamp; slicing truncates them.
    x1, y1, x2, y2 = (max(0, round(bb[k])) for k in range(4))
    img_c = img[:, :, y1:y2, x1:x2]

    # Keep the full image if the box collapsed to an empty region.
    if img_c.shape[-1] and img_c.shape[-2]:
        img = img_c
    return img


def load_image(path: Path, device: torch.device) -> torch.Tensor:
    """Load an image file as an RGB tensor scaled to [-1,1] with a batch dim.

    Args:
        path: Image file to open with PIL.
        device: Device the resulting tensor is moved to.

    Returns:
        Tensor of shape (1, C, H, W): ``ToTensor`` output (range [0,1])
        rescaled via ``* 2 - 1``.
    """
    # The context manager releases the file handle as soon as the pixel data
    # has been converted, instead of leaving it to garbage collection.
    with Image.open(path) as src:
        rgb = src.convert("RGB")
    tensor = transforms.ToTensor()(rgb).unsqueeze(0).to(device) * 2 - 1
    return tensor


def save_tensor_as_img(t: torch.Tensor, path: Path):
    """Write a [-1,1] tensor to ``path`` as an image.

    Values are clamped to [-1,1], rescaled to [0,1], the leading dimension is
    squeezed away, and the result is converted with ``ToPILImage`` on the CPU
    in float precision before saving.
    """
    unit_range = (t.clamp(-1, 1) + 1) / 2
    chw_cpu = unit_range.squeeze(0).cpu().float()
    to_pil = transforms.ToPILImage()
    to_pil(chw_cpu).save(path)


def imagenet_norm(x: torch.Tensor) -> torch.Tensor:
    """Apply ImageNet mean/std normalization to a [-1,1] image tensor.

    The input is first mapped to [0,1], then each of the three channels
    (dimension 1) is shifted by its mean and divided by its std.
    """
    channel_mean, channel_std = (
        torch.tensor(stats, device=x.device).view(1, 3, 1, 1)
        for stats in ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    )
    unit_range = (x + 1) / 2.0
    centered = unit_range - channel_mean
    return centered / channel_std


def fr_embed_from_minus1to1(x_m11: torch.Tensor, fr_model, fr_size: int) -> torch.Tensor:
    """Bilinearly resize ``x_m11`` to ``fr_size`` and return ``fr_model``'s output for it."""
    return fr_model(
        F.interpolate(
            x_m11,
            size=fr_size,
            mode="bilinear",
            align_corners=False,
        )
    )


def cosine_sim_m11(a_m11: torch.Tensor, target_emb: torch.Tensor, fr_model, fr_size: int) -> float:
    """Return the cosine similarity between an image's embedding and ``target_emb``.

    Args:
        a_m11: Image tensor, resized to ``fr_size`` and embedded by ``fr_model``.
        target_emb: Reference embedding to compare against.
        fr_model: Callable producing an embedding from the resized image.
        fr_size: Spatial size passed to the resize step.

    Returns:
        The similarity as a Python float (``.item()`` requires the result to
        hold exactly one element).
    """
    # The result leaves the autograd world as a plain float, so a graph is
    # never useful here; disabling grad avoids building one when the caller
    # is not already inside a no_grad block.
    with torch.no_grad():
        a_emb = fr_embed_from_minus1to1(a_m11, fr_model, fr_size)
        return torch.cosine_similarity(a_emb, target_emb).item()


def crop_by_bb_m11(x: torch.Tensor, bb):
    """Crop the last two dimensions of ``x`` to the box ``bb = (x1, y1, x2, y2)``.

    Coordinates are rounded to integers and clamped so that the result always
    spans at least one index per axis within the tensor's bounds. When ``bb``
    is None the input is returned unchanged.
    """
    if bb is None:
        return x

    left, top, right, bottom = (int(round(coord)) for coord in bb)
    height, width = x.shape[-2:]

    def _span(lo, hi, extent):
        # Keep the start inside the axis and force the end to lie past it.
        lo = max(0, min(lo, extent - 1))
        hi = max(lo + 1, min(hi, extent))
        return slice(lo, hi)

    return x[..., _span(top, bottom, height), _span(left, right, width)]

# --------------------- FR/target construction -----------------------

def build_fr_and_target(args, device: torch.device):
    """Load the FR model, embed the target image, and look up the threshold.

    Args:
        args: Namespace providing ``test_model_name``, ``target_choice``,
            ``MTCNN_cropping`` and ``far`` (also forwarded to
            ``load_FR_models``).
        device: Device passed to ``get_target_test_images``.

    Returns:
        ``(fr_model, fr_in_size, target_emb, abs_thr)`` where ``abs_thr`` is
        the ``TH_DICT`` entry for the model at the requested FAR bucket.

    Raises:
        ValueError: If ``args.far`` is not one of 0.1 / 0.01 / 0.001, if the
            model has no entry in ``TH_DICT``, or if ``load_FR_models`` does
            not return the requested model.
    """
    # Validate the cheap inputs first so a bad --far or model name fails
    # before any model weights or target images are loaded.
    far_map = {0.1: 0, 0.01: 1, 0.001: 2}
    idx = far_map.get(args.far)
    if idx is None:
        raise ValueError("--far must be 0.1 / 0.01 / 0.001")
    if args.test_model_name not in TH_DICT:
        raise ValueError(f"No threshold for {args.test_model_name} in TH_DICT")
    abs_thr = TH_DICT[args.test_model_name][idx]

    fr_dict = load_FR_models(args, [args.test_model_name])
    if args.test_model_name not in fr_dict:
        raise ValueError(f"Unknown test_model_name: {args.test_model_name}")
    fr_in_size, fr_model = fr_dict[args.test_model_name]
    fr_model.eval()

    # Target image; cropping only affects FR view
    _, tgt_img_m11 = get_target_test_images(args.target_choice, device, args.MTCNN_cropping)

    with torch.no_grad():
        target_emb = fr_embed_from_minus1to1(tgt_img_m11, fr_model, fr_in_size)

    return fr_model, fr_in_size, target_emb, abs_thr

# -------------------- Optimization with guardrail -------------------

def refine_image(
    pipe: StableDiffusionPipeline,
    adv_img_m11: torch.Tensor,     # [-1,1]
    orig_img_m11: torch.Tensor,    # [-1,1]
    base_sim: float,
    fr_size: int,
    fr_model,
    target_emb: torch.Tensor,
    fid_steps: int,
    fid_weight: float,
    abs_thr: float,
    fr_bb,                         # crop only for FR cos-sim
):
    """Optimize the VAE latent of ``adv_img_m11`` to lower LPIPS to ``orig_img_m11``.

    The adversarial image is encoded with ``pipe.vae``; the latent is then
    optimized with Adam in float32 for up to ``fid_steps`` steps. The loss is
    ``fid_weight * LPIPS(decoded, original)`` plus a small L2 penalty that
    keeps the latent near its starting point. LPIPS serves as the fidelity
    proxy that the log messages call "proxy-FID".

    Guardrail: when ``base_sim >= abs_thr`` the guard is ON. Each decoded
    image is scored against ``target_emb`` (on the ``fr_bb`` crop); a latent
    whose similarity falls below ``abs_thr`` is discarded and the optimizer
    is reset to the last latent that passed every check. If the very first
    reconstruction is already below the threshold, the original adversarial
    image is returned untouched. Non-finite decodes, losses, or gradients
    also trigger a restore, and additionally halve the learning rate.

    Args:
        pipe: Pipeline whose ``vae`` is used for encode/decode.
        adv_img_m11: Adversarial image in [-1,1].
        orig_img_m11: Clean reference image in [-1,1]; resized bilinearly when
            its spatial size differs from the image it is compared to.
        base_sim: Similarity of the unrefined adversarial image; decides
            whether the guard is enabled.
        fr_size: FR input size forwarded to ``cosine_sim_m11``.
        fr_model: FR model forwarded to ``cosine_sim_m11``.
        target_emb: Target identity embedding.
        fid_steps: Number of optimization iterations.
        fid_weight: Weight on the LPIPS term.
        abs_thr: Similarity threshold for the guard.
        fr_bb: Optional box for the FR view (None means no crop).

    Returns:
        ``(image, score)``: the decoded image with the lowest LPIPS among the
        accepted steps and that LPIPS value. If no step was accepted, the
        original adversarial image and its LPIPS to the reference.
    """
    device = adv_img_m11.device

    # Encode initial latent (VAE dtype); optimize in float32
    with torch.no_grad():
        adv_for_vae = adv_img_m11.to(dtype=pipe.vae.dtype)
        latent0_half = pipe.vae.encode(adv_for_vae).latent_dist.sample() * 0.18215
    latent0 = latent0_half.float()
    latent = latent0.clone().detach().requires_grad_(True)

    lr = 3e-3
    opt = torch.optim.Adam([latent], lr=lr, betas=(0.9, 0.999), eps=1e-8)

    lpips_fn = LPIPS(net="vgg").to(device).eval()
    lpips_fn.requires_grad_(False)

    # The reference never changes, so clamp it once outside the loop.
    orig_clamped = orig_img_m11.clamp(-1, 1)

    def _lpips_to_orig(img_m11: torch.Tensor) -> torch.Tensor:
        # LPIPS expects inputs in [-1,1] by default, so the clamped tensors
        # are passed directly without rescaling to [0,1].
        ref = orig_clamped
        if img_m11.shape[-2:] != ref.shape[-2:]:
            ref = F.interpolate(ref, size=img_m11.shape[-2:], mode="bilinear", align_corners=False)
        return lpips_fn(img_m11.clamp(-1, 1), ref).mean()

    def _restore(snapshot: torch.Tensor, halve_lr: bool) -> None:
        # Reset the optimized latent to a known-good snapshot, optionally
        # shrinking the step size, and drop any stale gradient.
        nonlocal lr
        with torch.no_grad():
            latent.copy_(snapshot)
        if halve_lr:
            lr *= 0.5
            for group in opt.param_groups:
                group["lr"] = lr
        opt.zero_grad()

    best_img, best_fid = None, float("inf")
    guard_on = base_sim >= abs_thr
    print(f"  guard={'ON' if guard_on else 'OFF'}  thr={abs_thr:.4f}")

    # Last latent that passed every check. It is snapshotted BEFORE the
    # optimizer step, so it is never the unverified post-step latent and a
    # restore always moves back to a genuinely validated point.
    last_safe_latent = latent.detach().clone()

    # small L2 on latent to discourage drift
    w_latent_l2 = 1e-3

    for step in range(fid_steps):
        # Decode (VAE dtype) then cast back to fp32
        latent_for_vae = (latent / 0.18215).to(dtype=pipe.vae.dtype)
        rec_half = pipe.vae.decode(latent_for_vae).sample
        if not torch.isfinite(rec_half).all():
            print(f"  step {step+1:>2}/{fid_steps}  decode non-finite → rollback & lr*=0.5")
            _restore(last_safe_latent, halve_lr=True)
            continue

        rec = rec_half.float()

        # Cos-sim on cropped FR view only
        with torch.no_grad():
            rec_view = crop_by_bb_m11(rec, fr_bb)
            sim = cosine_sim_m11(rec_view, target_emb, fr_model, fr_size)

        if guard_on and sim < abs_thr:
            if step == 0:
                # The VAE round-trip alone already breaks the attack:
                # keep the original adversarial image.
                print(f"  step 1/{fid_steps}  recon sim {sim:.4f} < thr {abs_thr:.4f} → keep original adv")
                with torch.no_grad():
                    base_fid = float(_lpips_to_orig(adv_img_m11))
                return adv_img_m11.detach().clone(), base_fid

            # Similarity dropped mid-optimization: discard this latent.
            print(f"  step {step+1:>2}/{fid_steps}  sim={sim:.4f} < thr {abs_thr:.4f}  → rollback")
            _restore(last_safe_latent, halve_lr=False)
            continue

        # LPIPS as proxy for FID
        fid_proxy = _lpips_to_orig(rec)

        if not torch.isfinite(fid_proxy):
            print(f"  step {step+1:>2}/{fid_steps}  fid_proxy non-finite → rollback & lr*=0.5")
            _restore(last_safe_latent, halve_lr=True)
            continue

        # Total loss
        loss = fid_weight * fid_proxy + w_latent_l2 * (latent - latent0).pow(2).mean()

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_([latent], max_norm=1.0)

        # NaN/inf guard on gradients
        if latent.grad is not None and not torch.isfinite(latent.grad).all():
            print(f"  step {step+1:>2}/{fid_steps}  gradient non-finite → rollback & lr*=0.5")
            _restore(last_safe_latent, halve_lr=True)
            continue

        # The current latent has passed decode, similarity, loss and gradient
        # checks: record it as the safe point before moving away from it.
        last_safe_latent = latent.detach().clone()
        opt.step()

        fp_scalar = float(fid_proxy.detach())
        if fp_scalar < best_fid:
            best_fid = fp_scalar
            best_img = rec.detach().clone()

        print(f"  step {step+1:>2}/{fid_steps}  proxy-FID={fp_scalar:.4f}  sim={sim:.4f}  lr={lr:g}")

    if best_img is None:
        # No step was accepted (fid_steps == 0 or every iteration was rolled
        # back): the unrefined adversarial image is the only validated result.
        with torch.no_grad():
            base_fid = float(_lpips_to_orig(adv_img_m11))
        return adv_img_m11.detach().clone(), base_fid
    return best_img, best_fid

# -------------------------- CLI / main loop -------------------------

def parse_args():
    """Define the command-line interface and parse the process arguments.

    Returns:
        The ``argparse.Namespace`` produced by ``ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser()

    # Required directories and run selectors.
    for flag in ("--adv_dir", "--orig_dir", "--output_dir"):
        parser.add_argument(flag, required=True)
    parser.add_argument("--target_choice", required=True, help="1/2/3/4")
    backbone_names = ["irse50", "ir152", "facenet", "mobile_face"]
    parser.add_argument("--test_model_name", required=True, choices=backbone_names)

    # Optimization settings.
    parser.add_argument("--fid_steps", type=int, default=70)
    parser.add_argument("--fid_weight", type=float, default=1.0)
    parser.add_argument("--device", default="cuda")

    # Stable Diffusion path
    parser.add_argument("--sd_path", default="../stable-diffusion-2-base")

    # FR-view cropping switch (consistent with tests); both spellings set the
    # same destination.
    parser.add_argument(
        "--MTCNN_cropping",
        dest="MTCNN_cropping",
        action="store_true",
        help="Use MTCNN crop for target & FR view (cos-sim only)",
    )
    parser.add_argument(
        "--mtcnn_crop",
        dest="MTCNN_cropping",
        action="store_true",
        help="Alias of --MTCNN_cropping",
    )
    parser.set_defaults(MTCNN_cropping=False)

    # FAR bucket for threshold lookup
    parser.add_argument(
        "--far",
        type=float,
        default=0.01,
        choices=[0.1, 0.01, 0.001],
        help="FAR bucket for threshold",
    )

    # Batch indices for formatting placeholders
    parser.add_argument(
        "--batch_idx",
        default=None,
        help="Comma-separated indices to expand {i} placeholders in args",
    )

    parsed = parser.parse_args()
    return parsed


def run_once(args):
    """Refine every adversarial image in ``args.adv_dir`` and write the results.

    For each image file in ``args.adv_dir`` that has a same-named file in
    ``args.orig_dir``, the baseline target similarity is measured, the image is
    refined with ``refine_image``, saved under ``args.output_dir`` with the
    same filename, and re-scored. Images without an original are skipped.

    Side effects:
        * Creates ``args.output_dir`` if needed.
        * Writes ``sim_baseline.txt`` (``filename<TAB>baseline similarity``)
          and ``opt_time.csv`` (``filename,seconds``) into the output dir.
        * Prints progress and summary counts of images at or above the
          threshold before and after refinement.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
    """
    device = torch.device(args.device)
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print("[Init] Loading Stable Diffusion…")
    pipe = StableDiffusionPipeline.from_pretrained(
        args.sd_path, torch_dtype=torch.float16
    ).to(device)
    pipe.safety_checker = None
    # Only the latent is optimized. Freezing the VAE keeps backward() from
    # computing and accumulating (never-zeroed) gradients on its weights.
    pipe.vae.requires_grad_(False)

    print("[Init] Building FR & target…")
    fr_model, fr_size, target_emb, abs_thr = build_fr_and_target(args, device)
    print(f"[Init] FR input size = {fr_size}, thr = {abs_thr:.6f} (model={args.test_model_name}, FAR={args.far})")

    start_cnt, end_cnt = 0, 0
    sim_log_path = out_dir / "sim_baseline.txt"
    time_log_path = out_dir / "opt_time.csv"

    with open(sim_log_path, "w", encoding="utf-8") as sim_file, \
         open(time_log_path, "w", encoding="utf-8") as time_file:

        time_file.write("filename,seconds\n")
        total_opt_time = 0.0
        n_imgs = 0

        adv_files = sorted(
            p for p in os.listdir(args.adv_dir)
            if p.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".webp"))
        )

        for fname in adv_files:
            adv_path  = Path(args.adv_dir) / fname
            orig_path = Path(args.orig_dir) / fname
            if not orig_path.exists():
                print(f"[Skip] {fname}: original image missing")
                continue

            print(f"\n[Process] {fname}")
            adv_img_m11  = load_and_maybe_crop(adv_path, device, do_crop=False)
            orig_img_m11 = load_and_maybe_crop(orig_path, device, do_crop=False)

            # The box is detected once on the adversarial image and reused
            # for every FR similarity evaluation of this image.
            fr_bb = None
            if args.MTCNN_cropping:
                # Close the file handle as soon as the pixels are converted.
                with Image.open(adv_path) as pil_adv:
                    fr_bb = alignment(pil_adv.convert("RGB"))

            # Pure evaluation: no autograd graph is needed for the score.
            with torch.no_grad():
                base_sim = cosine_sim_m11(crop_by_bb_m11(adv_img_m11, fr_bb), target_emb, fr_model, fr_size)
            sim_file.write(f"{fname}\t{base_sim:.6f}\n")
            print(f"  baseline cos-sim = {base_sim:.4f} (thr {abs_thr:.4f})")
            if base_sim >= abs_thr:
                start_cnt += 1

            t0 = time.perf_counter()
            refined_m11, best_fid = refine_image(
                pipe=pipe,
                adv_img_m11=adv_img_m11,
                orig_img_m11=orig_img_m11,
                base_sim=base_sim,
                fr_size=fr_size,
                fr_model=fr_model,
                target_emb=target_emb,
                fid_steps=args.fid_steps,
                fid_weight=args.fid_weight,
                abs_thr=abs_thr,
                fr_bb=fr_bb,
            )
            elapsed = time.perf_counter() - t0
            print(f"  optimize time = {elapsed:.2f}s")
            time_file.write(f"{fname},{elapsed:.3f}\n")
            total_opt_time += elapsed
            n_imgs += 1

            save_tensor_as_img(refined_m11, out_dir / fname)
            with torch.no_grad():
                final_sim = cosine_sim_m11(crop_by_bb_m11(refined_m11, fr_bb), target_emb, fr_model, fr_size)
            print(f"  final   cos-sim = {final_sim:.4f} (thr {abs_thr:.4f})  best proxy-FID = {best_fid:.4f}")
            if final_sim >= abs_thr:
                end_cnt += 1

        avg_time = (total_opt_time / n_imgs) if n_imgs else 0.0
        print(f"\nProcessed {n_imgs} images, total opt time {total_opt_time:.2f}s, avg {avg_time:.2f}s/img")
        print(f"Per-image times saved to: {time_log_path}")

    print(f"\nInitial PSR≥thr count: {start_cnt}")
    print(f"Final   PSR≥thr count: {end_cnt}")
    print("Done. Output dir:", args.output_dir)


def main():
    """Entry point: run once, or once per index when ``--batch_idx`` is given.

    In batch mode, each comma-separated index is substituted for the ``{i}``
    placeholder in the path-like arguments (and ``target_choice``) of a copy
    of the parsed namespace, and ``run_once`` is invoked with that copy.
    """
    args = parse_args()

    if args.batch_idx:
        # Batch mode: collect the non-empty, whitespace-trimmed indices.
        idx_list = []
        for token in args.batch_idx.split(","):
            token = token.strip()
            if token:
                idx_list.append(token)
        print(f"[Batch] indices = {idx_list}")

        placeholder_fields = ("adv_dir", "orig_dir", "output_dir", "target_choice", "sd_path")

        for i in idx_list:
            # Work on a shallow copy so the template arguments stay intact.
            run_args = argparse.Namespace(**vars(args))
            for field_name in placeholder_fields:
                template = getattr(run_args, field_name, None)
                if not isinstance(template, str) or "{i}" not in template:
                    continue
                setattr(run_args, field_name, template.format(i=i))

            # target_choice stays a string for downstream compatibility
            print(f"\n[Batch] Running i={i}")
            print(f"  adv_dir={run_args.adv_dir}")
            print(f"  orig_dir={run_args.orig_dir}")
            print(f"  output_dir={run_args.output_dir}")
            print(f"  target_choice={run_args.target_choice}")
            run_once(run_args)
    else:
        # Single run
        run_once(args)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
