import os
import time
import logging
import torch

from utils.model_loader import load_models_from_folder
from utils.config import config
from utils.data_loader import get_cifar10_data_half

# ---- 3rd‑party attack lib ----------------------------------------------------
try:
    from torchattacks import DeepFool, CW
except ImportError as e:
    raise ImportError("Please install torchattacks: pip install torchattacks") from e

# ────────────────────────────── logger & device ───────────────────────────────
# Use the current CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module logger; basicConfig installs a root handler at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ══════════════════════════════════════════════════════════════════════════════
#   Fingerprint‑extraction pipeline
# ══════════════════════════════════════════════════════════════════════════════

def _load_eval_models(folder, num_classes):
    """Load every model stored under *folder* onto ``device`` and switch each to eval mode.

    Returns the list produced by ``load_models_from_folder`` (possibly empty).
    """
    models = load_models_from_folder(folder, device=device, num_classes=num_classes)
    for model in models:
        model.eval()
    return models


def _build_attack(model, param_cfg, attack_mtd):
    """Build the torchattacks attack selected by *attack_mtd* against *model*.

    Raises:
        ValueError: if *attack_mtd* is neither ``"deepfool"`` nor ``"cw"``.
    """
    if attack_mtd == "deepfool":
        return DeepFool(model, steps=param_cfg.get("deepfool_max_iter", 50))
    if attack_mtd == "cw":
        return CW(
            model,
            c=param_cfg.get("cw_c", 1e-4),
            kappa=param_cfg.get("cw_kappa", 0),
            steps=param_cfg.get("cw_steps", 3000),
            lr=param_cfg.get("cw_lr", 0.01),
        )
    raise ValueError(f"Unknown attack_method={attack_mtd}")


def _max_logit_excluding(logits, cls):
    """Return the largest entry of the 1-D tensor *logits* other than index *cls*."""
    is_cls = torch.arange(logits.size(0), device=logits.device) == cls
    return logits.masked_fill(is_cls, float("-inf")).max()


def _batched_logits(model, xs, batch_size):
    """Run *model* over *xs* in chunks of at most *batch_size* rows and concatenate the outputs."""
    return torch.cat([model(chunk) for chunk in xs.split(batch_size or len(xs))], dim=0)


def _select_anchors(model, loader, m_anchor):
    """Stage 1: keep samples whose top-1 minus top-2 logit under *model* is >= *m_anchor*.

    Returns:
        ``(anchors, sample_shape)``. ``anchors`` is a list of ``(image_cpu, predicted_label)``
        pairs. ``sample_shape`` is the per-image shape of the last batch seen, or ``None`` if
        *loader* yielded nothing; it is used to shape an empty fingerprint tensor.
    """
    anchors, sample_shape = [], None
    with torch.no_grad():
        for imgs, _ in loader:
            imgs = imgs.to(device)
            logits = model(imgs)
            top2 = logits.topk(2, dim=1).values
            keep = ((top2[:, 0] - top2[:, 1]) >= m_anchor).cpu()
            preds = logits.argmax(1).cpu()
            anchors.extend(zip(imgs.cpu()[keep], preds[keep].tolist()))
            sample_shape = tuple(imgs.shape[1:])
    return anchors, sample_shape


def _find_boundary_points(model, attack, anchors, batch_size=None):
    """Stage 2: attack the anchors and keep those whose protected-model prediction flips.

    Args:
        batch_size: anchors passed to *attack* per call; ``None`` (or 0) attacks all at once.

    Returns:
        list of ``(x0_cpu, y0, delta_cpu)`` with ``delta = adv - x0``.
    """
    if not anchors:
        return []
    step = batch_size or len(anchors)
    boundary_pts = []
    for start in range(0, len(anchors), step):
        chunk = anchors[start:start + step]
        x0 = torch.stack([x for x, _ in chunk]).to(device)
        y0 = torch.tensor([y for _, y in chunk], device=device)
        adv = attack(x0, y0).detach()  # the attack itself needs autograd, so no no_grad here
        with torch.no_grad():
            flipped = (model(adv).argmax(1) != y0).tolist()
        for (x, y), adv_x, hit in zip(chunk, adv.cpu(), flipped):
            if hit:
                boundary_pts.append((x, y, adv_x - x))
    return boundary_pts


def _expand_boundary_point(idx, x0_cpu, y0, delta_cpu, prot_model, ind_models, pir_models,
                           *, q_margin, q_lip, q_eps, n_grid, grid_batch_size):
    """Stage 3 for one boundary point: choose the expansion factor τ* along δ.

    Candidates are x*(τ) = x0 + τ·δ for τ on a uniform grid of *n_grid* points over
    [1, τ_high], where:

    * τ_high = m_bound / (2 · L_loc · ||δ||).
    * m_bound is the *q_margin* quantile, over the independent models, of the margin at x0
      with respect to y0.
    * L_loc is the *q_lip* quantile of ||f(x0 + δ) − f(x0)|| / ||δ||.

    A candidate is feasible when τ ≥ τ_low(τ), where:

    * τ_low(τ) = 1 + 2 · ε(x*) / max(c_g · ||δ||, 1e-12).
    * c_g is the gradient norm of the protected model's top-1 margin at x0 + δ.
    * ε(x*) is the *q_eps* quantile, over the pirated models, of
      max |pirated(x*) − protected(x*)|.

    Among feasible τ, the one maximizing min(τ − τ_low, τ_high − τ) wins; ties go to the
    smallest τ.

    Returns:
        ``(x_star_cpu, y_star)``, with y_star the protected model's prediction at x*, or
        ``None`` if the point is rejected.
    """
    x0, delta = x0_cpu.to(device), delta_cpu.to(device)
    norm_d = delta.reshape(-1).norm().item()

    # c_g: norm of the gradient of the protected model's top-1 margin at q = x0 + δ.
    q = (x0 + delta).detach().requires_grad_(True)
    logits_q = prot_model(q.unsqueeze(0)).squeeze(0)
    y_q = int(logits_q.argmax())
    g_q = logits_q[y_q] - _max_logit_excluding(logits_q, y_q)
    (grad_q,) = torch.autograd.grad(g_q, q)
    c_g = grad_q.reshape(-1).norm().item()
    q = q.detach()

    # Per-anchor (m_bound, L_loc) from the independent models.
    margins_ind, lips = [], []
    with torch.no_grad():
        for model in ind_models:
            l0 = model(x0.unsqueeze(0)).squeeze(0)
            lq = model(q.unsqueeze(0)).squeeze(0)
            margins_ind.append(l0[y0].item() - _max_logit_excluding(l0, y0).item())
            lips.append((lq - l0).reshape(-1).norm().item() / max(norm_d, 1e-12))
    m_bound = float(torch.tensor(margins_ind).quantile(q_margin))
    L_loc = float(torch.tensor(lips).quantile(q_lip))

    denom_high = 2.0 * L_loc * norm_d
    if denom_high <= 0.0:
        # A zero Lipschitz estimate would divide by zero; the bound carries no information.
        logger.info(f"[REJ {idx}] degenerate 2·L_loc·||δ||={denom_high:.4g} → discard")
        return None
    tau_high = m_bound / denom_high
    if not tau_high >= 1.0:  # empty interval [1, τ_high]; the negated form also rejects NaN
        return None
    if not pir_models:  # ε(x*) is undefined, so no τ can be certified
        logger.info(f"[REJ {idx}] No feasible τ in [1, τ_high={tau_high:.4f}]")
        return None

    # Evaluate all grid candidates in batches rather than one forward pass per τ.
    # The grid is float64 so its last point equals τ_high exactly (no rounding overshoot).
    taus = torch.linspace(1.0, tau_high, steps=n_grid, dtype=torch.float64)
    tau_col = taus.to(device=device, dtype=delta.dtype).view((-1,) + (1,) * delta.dim())
    x_stars = x0.unsqueeze(0) + tau_col * delta.unsqueeze(0)
    with torch.no_grad():
        base = _batched_logits(prot_model, x_stars, grid_batch_size)
        # (n_grid, n_pirated): max |logit shift| per candidate and pirated model.
        shifts = torch.stack(
            [(_batched_logits(pm, x_stars, grid_batch_size) - base).abs().flatten(1).amax(dim=1)
             for pm in pir_models],
            dim=1,
        )
    eps_stars = shifts.float().cpu().quantile(q_eps, dim=1).tolist()

    denom_low = max(c_g * norm_d, 1e-12)  # loop-invariant
    best_i, best_score, best_tau_low = None, -float("inf"), None
    for i, (tau, eps_star) in enumerate(zip(taus.tolist(), eps_stars)):
        tau_low = 1.0 + (2.0 * eps_star) / denom_low
        if tau + 1e-9 < tau_low:  # τ must clear its own lower bound τ_low(τ)
            continue
        score = min(tau - tau_low, tau_high - tau)
        if score >= 0 and score > best_score:  # strict '>' keeps the smallest τ on ties
            best_i, best_score, best_tau_low = i, score, tau_low

    if best_i is None:
        logger.info(f"[REJ {idx}] No feasible τ in [1, τ_high={tau_high:.4f}]")
        return None

    best_tau = taus[best_i].item()
    logger.info(
        f"[ACC {idx}] ||δ||={norm_d:.4f}, c_g={c_g:.4f}, "
        f"m_bound={m_bound:.4f}, L_loc={L_loc:.4f}, "
        f"τ_low*={best_tau_low:.4f}, τ*={best_tau:.4f}, τ_high={tau_high:.4f}, "
        f"score={best_score:.4f}"
    )
    # copy=True so the kept row does not pin the whole candidate batch in memory.
    return x_stars[best_i].to("cpu", copy=True), int(base[best_i].argmax())


def extract_fingerprints(param_cfg, data_cfg):
    """Extract fingerprint inputs for the protected model and save them to disk.

    Pipeline (each stage is timed and logged):
      1. Anchor selection. Take samples from the loader returned third by
         ``get_cifar10_data_half("first")`` whose protected-model top-1 minus top-2 logit
         is at least ``m_min``.
      2. Minimal-flip perturbations. Run DeepFool or CW on the anchors and keep only those
         whose prediction flips, with δ = adv − x0.
      3. Safe expansion. For each boundary point, a τ grid search picks x* = x0 + τ*·δ from
         per-anchor (m_bound, L_loc) over the independent models and per-candidate
         ε_logit(x*) over the pirated models.
      4. Save. Write ``{"fingerprints": (N, *image_shape) tensor, "labels": (N,) long tensor}``
         to ``<fingerprints_dir>/<fingerprint_file>``.

    Args:
        param_cfg: mapping. Required keys:
            - ``protected_model_dir``
            - ``independent_models_dir``
            - ``pirated_models_dir``

            Optional keys (defaults in parentheses):
            - ``m_min`` (5.0)
            - ``margin_lower_quantile`` (0.5)
            - ``lip_upper_quantile`` (0.5)
            - ``eps_upper_quantile`` (1.0)
            - ``attack_method`` ("deepfool"; or "cw")
            - ``fingerprints_dir`` ("fingerprints")
            - ``fingerprint_file`` ("TGF_fingerprint_set_43.pt")
            - ``deepfool_max_iter`` (50)
            - ``cw_c`` (1e-4)
            - ``cw_kappa`` (0)
            - ``cw_steps`` (3000)
            - ``cw_lr`` (0.01)
            - ``tau_grid_size`` (500)
            - ``tau_grid_batch_size`` (same as ``tau_grid_size``)
            - ``attack_batch_size`` (None, meaning attack all anchors in one call)
        data_cfg: mapping providing ``num_classes``.

    Raises:
        ValueError: unknown ``attack_method``, or no independent models were loaded.
    """
    # ── Load protected model ─────────────────────────────────────────────────
    t_start = time.perf_counter()
    num_classes = data_cfg["num_classes"]
    prot_model = _load_eval_models(param_cfg["protected_model_dir"], num_classes)[0]
    logger.info("Loaded protected model 🛡️")
    t_prot = time.perf_counter()
    logger.info(f"Time for loading protected model: {t_prot - t_start:.2f}s")

    # ── Load surrogates ──────────────────────────────────────────────────────
    ind_models = _load_eval_models(param_cfg["independent_models_dir"], num_classes)
    pir_models = _load_eval_models(param_cfg["pirated_models_dir"], num_classes)
    t_surrogates = time.perf_counter()
    logger.info(f"Loaded {len(ind_models)} independent and {len(pir_models)} pirated models: {t_surrogates - t_prot:.2f}s")
    if not ind_models:
        raise ValueError(
            f"No independent models found in {param_cfg['independent_models_dir']!r}; "
            "they are required to estimate m_bound and L_loc"
        )
    if not pir_models:
        logger.warning("No pirated models loaded: every boundary point will be rejected")

    # ── Hyper‑parameters & attack ─────────────────────────────────────────────
    m_anchor = param_cfg.get("m_min", 5.0)
    q_margin = param_cfg.get("margin_lower_quantile", 0.5)
    q_lip = param_cfg.get("lip_upper_quantile", 0.5)
    q_eps = param_cfg.get("eps_upper_quantile", 1.0)
    attack_mtd = param_cfg.get("attack_method", "deepfool")
    out_dir = param_cfg.get("fingerprints_dir", "fingerprints")
    out_name = param_cfg.get("fingerprint_file", "TGF_fingerprint_set_43.pt")
    n_grid = param_cfg.get("tau_grid_size", 500)
    grid_batch_size = param_cfg.get("tau_grid_batch_size", n_grid)
    attack_batch_size = param_cfg.get("attack_batch_size")
    logger.info(f"Hyper-parameters: m_anchor: {m_anchor}, q_margin: {q_margin}, q_lip: {q_lip}, "
                f"q_eps: {q_eps}, attack_mtd: {attack_mtd}")
    os.makedirs(out_dir, exist_ok=True)

    attack = _build_attack(prot_model, param_cfg, attack_mtd)
    t_attack = time.perf_counter()
    logger.info(f"Setup attack ({attack_mtd}): {t_attack - t_surrogates:.2f}s")

    # ═══════════ Stage 1: Anchor selection ══════════════════════════════════
    t1 = time.perf_counter()
    _, _, dl_test, _ = get_cifar10_data_half("first")
    anchors, sample_shape = _select_anchors(prot_model, dl_test, m_anchor)
    t2 = time.perf_counter()
    logger.info(f"Stage 1: selected {len(anchors)} anchors (m_anchor={m_anchor}), time: {t2 - t1:.2f}s")

    # ═══════════ Stage 2: Minimal-flip perturbations ════════════════════════
    t3 = time.perf_counter()
    boundary_pts = _find_boundary_points(prot_model, attack, anchors, attack_batch_size)
    t4 = time.perf_counter()
    logger.info(f"Stage 2: found {len(boundary_pts)} boundary pts, time: {t4 - t3:.2f}s")

    # ═══════════ Stage 3: Safe-expansion with x*-based ε estimation ════════
    t5 = time.perf_counter()
    safe_pts = []
    for idx, (x0_cpu, y0, delta_cpu) in enumerate(boundary_pts):
        accepted = _expand_boundary_point(
            idx, x0_cpu, y0, delta_cpu, prot_model, ind_models, pir_models,
            q_margin=q_margin, q_lip=q_lip, q_eps=q_eps,
            n_grid=n_grid, grid_batch_size=grid_batch_size,
        )
        if accepted is not None:
            safe_pts.append(accepted)
    t6 = time.perf_counter()
    logger.info(f"Stage 3: expanded to {len(safe_pts)} safe pts, time: {t6 - t5:.2f}s")

    # ═══════════ Stage 4: Save ────────────────────────────────────────────
    t7 = time.perf_counter()
    if safe_pts:
        fps = torch.stack([x for x, _ in safe_pts])
        labs = torch.tensor([y for _, y in safe_pts])
    else:
        # The shape comes from Stage 1, so this works even when no boundary point was found.
        fps = torch.empty((0, *(sample_shape or ())))
        labs = torch.empty((0,), dtype=torch.long)
    torch.save({"fingerprints": fps, "labels": labs}, os.path.join(out_dir, out_name))
    t8 = time.perf_counter()
    logger.info(f"Stage 4: saved {len(safe_pts)} fingerprints, time: {t8 - t7:.2f}s")
    logger.info(f"Total extraction time: {t8 - t_start:.2f}s")


if __name__ == "__main__":
    # Script entry: run TGF extraction using the "TGF" and "data" sections of the project config.
    extract_fingerprints(config.get("TGF"), config.get("data"))
