# -*- coding: utf-8 -*-
"""
pruning_attack.py  (split panels)
Generates TWO files per K:
  - K{K}_z_curves.pdf        (Z vs pruning ratio)
  - K{K}_tradeoff.pdf        (Accuracy vs Z)
Also writes: prune_curve_summary.csv

Assumed layout:
  repo/
    src/...
    example/
      pruning_attack.py
      unntrusted_cifar10_results/   <-- checkpoints here (recursively)
"""

import os, sys, re, copy
from pathlib import Path
from typing import Dict, Tuple, Optional, List
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import matplotlib.pyplot as plt
import pandas as pd
from torch.utils.data import DataLoader

# ── Paths & imports ──────────────────────────────────────────────────────────
# All inputs and outputs are resolved relative to this script's own directory,
# not the current working directory.
EXAMPLE_DIR = Path(__file__).resolve().parent         # .../example
ROOT = EXAMPLE_DIR.parent                              # repo root
# The repo root must be on sys.path before the `from src...` imports below run.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))                     # allow `from src...` imports

# Checkpoints and flip-vector files are searched for (recursively) under this
# directory; the script fails fast at import time if it does not exist.
RESULTS_ROOT = EXAMPLE_DIR / "unntrusted_cifar10_results"
if not RESULTS_ROOT.exists():
    raise FileNotFoundError(f"Expected results dir not found: {RESULTS_ROOT}")

# Destination for the generated PDFs and the CSV summary; created on import.
OUT_DIR = EXAMPLE_DIR / "prune_attack_figs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

from src.ResNet import ResNet18
from src.data_utils import get_cifar10_transforms, get_cifar10_dataset
from src.utils import get_cosine_similarity_model

# ── Config ───────────────────────────────────────────────────────────────────
# Seeds, K values and pruning ratios swept by this script.
SEEDS  = [0, 1, 2]
K_LIST = [2, 4, 8, 16, 32, 64, 128]
PRUNE_RATIOS = [0.3, 0.5, 0.7, 0.9]  # attack strengths

# Watermark detector: z = (cos - mean) / std, where cos is the value returned
# by get_cosine_similarity_model (see z_score).
# NOTE(review): mean/std presumably mirror the values used during training — confirm.
DETECTION_MEAN = 0.0
DETECTION_STD  = 0.01

# Use the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EVAL_BATCH = 128
# Optional cap on the number of test batches per evaluation, read from the
# EVAL_MAX_BATCHES environment variable. int() raises ValueError if the
# variable is set to something that is not an integer.
EVAL_MAX_BATCHES = int(os.environ.get("EVAL_MAX_BATCHES", "0"))  # 0 = full eval

# ── Data (safe test loader) ─────────────────────────────────────────────────
# Datasets are built once at import time. In this file only test_dataset is
# consumed afterwards (through test_loader).
transform_train, transform_test = get_cifar10_transforms()
train_dataset, val_dataset, test_dataset = get_cifar10_dataset(transform_train, transform_test)
# Fixed-order, single-process loader: no shuffling and no worker processes,
# so every evaluation sees the batches in the same order.
test_loader = DataLoader(
    test_dataset,
    batch_size=EVAL_BATCH,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
    persistent_workers=False,
)

# ── Find best models + keys (recursive) ──────────────────────────────────────
# Filename patterns for the two artifact kinds indexed by scan_files. Both
# capture K, lr, c, steps, bs and seed as named groups and are anchored at the
# end of the name with `$`.
# NOTE: lr and c accept only digits and dots, so a value written in scientific
# notation (e.g. 1e-05) would not match.
RE_BEST = re.compile(
    r"highest_validation_accuracy_model_"
    r"K(?P<K>\d+)_lr(?P<lr>[\d.]+)_c(?P<c>[\d.]+)_steps(?P<steps>\d+)_bs(?P<bs>\d+)_seed(?P<seed>\d+)\.pt$"
)
RE_KEY = re.compile(
    r"cifar10_flip_vectors_"
    r"K(?P<K>\d+)_lr(?P<lr>[\d.]+)_c(?P<c>[\d.]+)_steps(?P<steps>\d+)_bs(?P<bs>\d+)_seed(?P<seed>\d+)\.pt$"
)

def _rc(x: float, nd: int = 6) -> float:
    """Round c to a fixed precision for robust matching."""
    return float(f"{x:.{nd}f}")

def scan_files() -> Tuple[Dict[Tuple[int,int,float], str], Dict[Tuple[int,int,float], str]]:
    """
    Walk RESULTS_ROOT recursively and index every ``*.pt`` file whose name
    matches RE_BEST (best-model checkpoints) or RE_KEY (flip-vector files).

    Returns:
      (best_models, flip_keys): two dicts mapping (K, seed, c_rounded) to the
      file path as a string. When several files share a key, the one with the
      strictly newest modification time is kept.
    """
    best_models: Dict[Tuple[int,int,float], str] = {}
    flip_keys:   Dict[Tuple[int,int,float], str] = {}
    # Patterns are tried in this order; the first one that matches claims the file.
    targets = ((RE_BEST, best_models), (RE_KEY, flip_keys))

    for path in RESULTS_ROOT.rglob("*.pt"):
        path_str = str(path)
        for pattern, index in targets:
            hit = pattern.match(path.name)
            if hit is None:
                continue
            key = (int(hit["K"]), int(hit["seed"]), _rc(float(hit["c"])))
            previous = index.get(key)
            if previous is None or os.path.getmtime(path_str) > os.path.getmtime(previous):
                index[key] = path_str
            break
    return best_models, flip_keys

# Both indexes are built once, at import time; find_paths and
# cs_by_K_from_index read them as module-level globals.
BEST_INDEX, KEY_INDEX = scan_files()

def cs_by_K_from_index() -> Dict[int, List[float]]:
    """Map each K present in BEST_INDEX to the ascending list of its distinct c values."""
    grouped = defaultdict(set)
    for k_value, _seed, c_value in BEST_INDEX:
        grouped[k_value].add(c_value)
    return {k_value: sorted(c_values) for k_value, c_values in grouped.items()}

def find_paths(K: int, seed: int, c_target: float) -> Tuple[Optional[str], Optional[str]]:
    """
    Resolve the (model_path, key_path) pair for a given (K, seed, c).

    The rounded c is looked up directly first. If that fails, the closest c
    indexed in BEST_INDEX for the same (K, seed) is tried, provided it lies
    within 1e-4 of the target. A pair is only returned when both the model and
    the flip-vector file are indexed; otherwise (None, None).
    """
    c_wanted = _rc(c_target)

    def _lookup(c_value):
        # Both artifacts must exist for the key to be usable.
        key = (K, seed, c_value)
        if key in BEST_INDEX and key in KEY_INDEX:
            return BEST_INDEX[key], KEY_INDEX[key]
        return None

    found = _lookup(c_wanted)
    if found is not None:
        return found

    # Sorted so that ties in distance resolve to the smaller c.
    candidates = sorted({c for (k, s, c) in BEST_INDEX if k == K and s == seed})
    if not candidates:
        return None, None

    closest = min(candidates, key=lambda c: abs(c - c_wanted))
    if abs(closest - c_wanted) > 1e-4:
        return None, None

    found = _lookup(closest)
    if found is None:
        return None, None
    return found

# ── Pruning helpers ──────────────────────────────────────────────────────────
def _iter_prunable_params(model: nn.Module):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)) and hasattr(m, "weight") and m.weight is not None:
            yield (m, "weight")

def prune_global_unstructured(model: nn.Module, amount: float) -> nn.Module:
    """
    Global magnitude (unstructured) pruning across Conv/Linear weights.

    Args:
      model: network to attack; it is deep-copied, so the caller's model is
        left untouched.
      amount: pruning amount forwarded to ``prune.global_unstructured``.

    Returns:
      The pruned copy, moved to DEVICE. Pruning itself is carried out on CPU.
    """
    pruned = copy.deepcopy(model).cpu()
    params = list(_iter_prunable_params(pruned))
    if params:
        prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=amount)
        # Make the pruning permanent so the copy exposes a plain `weight` again.
        for mod, name in params:
            try:
                prune.remove(mod, name)
            except ValueError:
                # prune.remove raises ValueError when the parameter has no
                # pruning attached; that is harmless here. Anything else
                # (including KeyboardInterrupt) is allowed to propagate.
                pass
    return pruned.to(DEVICE)

def prune_structured_outchannel(model: nn.Module, amount: float) -> nn.Module:
    """
    Per-module structured pruning of OUT channels/features (LnStructured).

    Each Conv2d/Linear weight is pruned independently with
    ``prune.ln_structured(..., n=1, dim=0)``.

    Args:
      model: network to attack; it is deep-copied, so the caller's model is
        left untouched.
      amount: pruning amount forwarded to ``prune.ln_structured`` for every module.

    Returns:
      The pruned copy, moved to DEVICE. Pruning itself is carried out on CPU.
    """
    pruned = copy.deepcopy(model).cpu()
    for mod in pruned.modules():
        if isinstance(mod, (nn.Conv2d, nn.Linear)) and hasattr(mod, "weight") and mod.weight is not None:
            prune.ln_structured(mod, name="weight", amount=amount, n=1, dim=0)
            # Make the pruning permanent so the copy exposes a plain `weight` again.
            try:
                prune.remove(mod, "weight")
            except ValueError:
                # prune.remove raises ValueError when the parameter has no
                # pruning attached; that is harmless here. Anything else
                # (including KeyboardInterrupt) is allowed to propagate.
                pass
    return pruned.to(DEVICE)

# ── Evaluate helpers ─────────────────────────────────────────────────────────
@torch.no_grad()
def evaluate_top1(model: nn.Module, loader) -> float:
    """
    Top-1 accuracy (in percent) of ``model`` over ``loader``.

    The model is switched to eval mode. When EVAL_MAX_BATCHES is non-zero,
    evaluation stops after that many batches. Returns NaN if no sample was seen.
    """
    model.eval()
    n_seen = 0
    n_right = 0
    for batch_no, (inputs, labels) in enumerate(loader, start=1):
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        predictions = model(inputs).argmax(1)
        n_seen += labels.size(0)
        n_right += (predictions == labels).sum().item()
        # The cap is checked after the batch is counted, so exactly
        # EVAL_MAX_BATCHES batches contribute to the result.
        if EVAL_MAX_BATCHES and batch_no >= EVAL_MAX_BATCHES:
            break
    if not n_seen:
        return float("nan")
    return 100.0 * n_right / n_seen

def z_score(model: nn.Module, flip_vectors: Dict[str, torch.Tensor]) -> float:
    """Return the detector statistic ``(cos - DETECTION_MEAN) / DETECTION_STD`` as a Python float.

    ``cos`` is whatever get_cosine_similarity_model returns for the model and
    the given flip vectors.
    """
    similarity = get_cosine_similarity_model(model, flip_vectors)
    centered = similarity - DETECTION_MEAN
    return float(centered / DETECTION_STD)

# ── Aggregate across seeds for each (K,c) ────────────────────────────────────
def eval_for_ratio(model: nn.Module, flip_vectors, ratio: float, method: str):
    """
    Prune a copy of ``model`` and measure the result.

    Args:
      method: "magnitude" (global unstructured) or "structured" (out-channel).
      ratio: pruning amount passed to the selected pruning function.

    Returns:
      (z, acc): watermark z-score and test-set top-1 accuracy of the pruned copy.

    Raises:
      ValueError: for any other ``method``, before any pruning is attempted.
    """
    if method == "magnitude":
        attack = prune_global_unstructured
    elif method == "structured":
        attack = prune_structured_outchannel
    else:
        raise ValueError(method)

    attacked = attack(model, amount=ratio)
    z_after = z_score(attacked, flip_vectors)
    acc_after = evaluate_top1(attacked, test_loader)
    return z_after, acc_after

def average_over_seeds(K: int, c: float):
    """
    Run the baseline and both pruning attacks for every seed in SEEDS that has
    a checkpoint and flip-vector file for (K, c), then average the results.

    Seeds with a missing file are reported with a "[MISS]" line and skipped.

    Returns:
      baseline: (z0_mean, acc0_mean) of the unpruned models
      curves: dict method -> list of (ratio, z_mean, acc_mean), one entry per
        ratio in PRUNE_RATIOS; means are NaN when no seed contributed
      n: number of seeds averaged
    """
    methods = ("magnitude", "structured")
    baseline_z, baseline_acc = [], []
    collected = {
        method: {ratio: {"z": [], "a": []} for ratio in PRUNE_RATIOS}
        for method in methods
    }

    for seed in SEEDS:
        model_path, key_path = find_paths(K, seed, c)
        if model_path is None or key_path is None:
            print(f"[MISS] K={K}, c={c}, seed={seed}")
            continue

        model = ResNet18().to(DEVICE)
        state = torch.load(model_path, map_location=DEVICE)
        model.load_state_dict(state)
        flip_vectors = torch.load(key_path, map_location=DEVICE)

        # Reference point: the unpruned model.
        z_ref = z_score(model, flip_vectors)
        acc_ref = evaluate_top1(model, test_loader)
        baseline_z.append(z_ref)
        baseline_acc.append(acc_ref)

        # Every attack starts from the same unpruned model; the pruning
        # helpers work on a copy.
        for ratio in PRUNE_RATIOS:
            for method in methods:
                z_val, acc_val = eval_for_ratio(model, flip_vectors, ratio, method)
                bucket = collected[method][ratio]
                bucket["z"].append(z_val)
                bucket["a"].append(acc_val)

    def _mean_or_nan(values):
        # np.mean on an empty list would warn; return NaN explicitly instead.
        if not values:
            return float("nan")
        return float(np.mean(values))

    baseline = (_mean_or_nan(baseline_z), _mean_or_nan(baseline_acc))
    curves = {}
    for method, by_ratio in collected.items():
        curves[method] = [
            (ratio, _mean_or_nan(by_ratio[ratio]["z"]), _mean_or_nan(by_ratio[ratio]["a"]))
            for ratio in PRUNE_RATIOS
        ]
    return baseline, curves, len(baseline_z)

# ── Plotting (SPLIT — two files per K) ───────────────────────────────────────
def plot_left_z_curves(K: int, results_for_c: Dict[float, dict]):
    """
    Plot watermark Z-score against pruning ratio and save it as
    ``K{K}_z_curves.pdf`` in OUT_DIR.

    One colour per c value; magnitude pruning is drawn solid with filled
    circles, structured pruning dashed with hollow squares. A red dashed line
    marks the Z=4 detection threshold.

    Args:
      K: used only to name the output file.
      results_for_c: c -> {"magnitude": [...], "structured": [...]}, each a
        list of (ratio, z, acc) tuples.
    """
    from matplotlib.lines import Line2D

    ordered_cs = sorted(results_for_c.keys())
    palette = plt.get_cmap("viridis")
    span = max(1, len(ordered_cs)-1)
    color_of = {c: palette(idx / span) for idx, c in enumerate(ordered_cs)}

    fig, ax = plt.subplots(figsize=(12, 5))

    # Legend entries are built by hand so they carry the exact line/marker style.
    handles = []

    for c in ordered_cs:
        shade = color_of[c]
        # One style dict per method, shared by the drawn curve and its legend proxy.
        styles = (
            ("magnitude", dict(marker="o", linestyle="-", color=shade, linewidth=2,
                               markerfacecolor=shade, markeredgecolor="white",
                               markeredgewidth=1, markersize=6)),
            ("structured", dict(marker="s", linestyle="--", color=shade, linewidth=2,
                                markerfacecolor="white", markeredgecolor=shade,
                                markeredgewidth=2, markersize=6)),
        )
        for method, style in styles:
            series = results_for_c[c][method]
            ratios = [r for (r, _, _) in series]
            z_values = [z for (_, z, _) in series]
            ax.plot(ratios, z_values, **style)
            handles.append(Line2D([0], [0], label=f"C={c:.3f} ({method})", **style))

    # Detection threshold.
    ax.axhline(4.0, ls="--", color="red", linewidth=2)
    handles.append(Line2D([0], [0], color="red", linestyle="--", linewidth=2,
                          label="Detection Threshold (Z=4)"))

    ax.set_xlabel("Pruning Ratio")
    ax.set_ylabel("Watermark Z-Score")
    ax.set_title("Impact of Pruning on Watermark")
    ax.grid(True, alpha=0.25)
    ax.set_axisbelow(True)

    # Legend sits outside the axes, to the right.
    ax.legend(handles=handles, fontsize=8, ncol=2,
              bbox_to_anchor=(1.05, 1), loc="upper left")

    fig.tight_layout()
    fig.savefig(OUT_DIR / f"K{K}_z_curves.pdf", bbox_inches="tight")
    plt.close(fig)

def plot_right_tradeoff(K: int, results_for_c: Dict[float, dict]):
    """
    Plot test accuracy against watermark Z-score and save it as
    ``K{K}_tradeoff.pdf`` in OUT_DIR.

    For each c value the unpruned model is a large circle, and each pruning
    method is a path through its (z, acc) points, every point annotated with
    its pruning ratio. A red dashed vertical line marks the Z=4 threshold.

    Args:
      K: used in the title and the output file name.
      results_for_c: c -> {"baseline": (z0, acc0), "magnitude": [...],
        "structured": [...]}, the latter two lists of (ratio, z, acc) tuples.
    """
    from matplotlib.lines import Line2D

    ordered_cs = sorted(results_for_c.keys())
    palette = plt.get_cmap("viridis")
    span = max(1, len(ordered_cs)-1)
    color_of = {c: palette(idx / span) for idx, c in enumerate(ordered_cs)}

    fig, ax = plt.subplots(figsize=(12, 5))

    # Legend entries are built by hand so they carry the exact line/marker style.
    handles = []

    def _draw_path(series, c, shade, method, linestyle, marker, size):
        # Connecting line first, then one annotated marker per pruning ratio.
        z_values = [z for (_, z, _) in series]
        accs = [a for (_, _, a) in series]
        ax.plot(z_values, accs, color=shade, linestyle=linestyle, alpha=0.9, linewidth=2)
        handles.append(Line2D([0], [0], color=shade, linestyle=linestyle, linewidth=2,
                              label=f"C={c:.3f} ({method})"))
        for ratio, z, acc in series:
            ax.scatter([z], [acc], color=shade, marker=marker, s=size)
            ax.annotate(f"{ratio}", (z, acc), fontsize=7)

    # The threshold goes in first so it heads the legend.
    ax.axvline(4.0, ls="--", color="red", linewidth=2)
    handles.append(Line2D([0], [0], color="red", linestyle="--", linewidth=2,
                          label="Detection Threshold (Z=4)"))

    for c in ordered_cs:
        shade = color_of[c]
        entry = results_for_c[c]

        # Unpruned reference point.
        z_ref, acc_ref = entry["baseline"]
        ax.scatter([z_ref], [acc_ref], color=shade, marker="o", s=80,
                   edgecolor="black", linewidth=1)
        handles.append(Line2D([0], [0], color=shade, marker="o", linestyle="None",
                              markerfacecolor=shade, markeredgecolor="black",
                              markeredgewidth=1, markersize=8,
                              label=f"C={c:.3f} (original)"))

        _draw_path(entry["magnitude"], c, shade, "magnitude", "-", "^", 25)
        _draw_path(entry["structured"], c, shade, "structured", "--", "s", 20)

    ax.set_xlabel("Watermark Z-Score")
    ax.set_ylabel("Test Accuracy (%)")
    ax.set_title(f"Trade-off: Accuracy vs Watermark Detectability After Pruning (K={K})")
    ax.grid(True, alpha=0.25)
    ax.set_axisbelow(True)

    # Legend sits outside the axes, to the right.
    ax.legend(handles=handles, fontsize=7, ncol=2,
              bbox_to_anchor=(1.05, 1), loc="upper left")

    fig.tight_layout()
    fig.savefig(OUT_DIR / f"K{K}_tradeoff.pdf", bbox_inches="tight")
    plt.close(fig)

# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """
    Run the pruning-attack sweep and write all outputs.

    For every K in K_LIST that has indexed checkpoints, the selected c values
    are evaluated with average_over_seeds, the two per-K figures are written,
    and one row per (K, c, method, ratio) is collected. The rows are saved to
    ``prune_curve_summary.csv`` in OUT_DIR when there is at least one.
    """
    print(f"Device: {DEVICE}")

    # Auto-discover which c's exist per K
    available_cs = cs_by_K_from_index()
    for k in sorted(available_cs):
        print(f"[INFO] K={k} -> c values found: {available_cs[k]}")

    rows = []
    for K in K_LIST:
        c_list = available_cs.get(K, [])
        if not c_list:
            print(f"[WARN] No checkpoints indexed for K={K}; skipping.")
            continue

        # NOTE(review): the smallest discovered c is left out of the sweep.
        # Presumably it is a reference setting that should not be plotted —
        # confirm, and evaluate `c_list` directly if every c is wanted.
        # The skip is logged so that it is visible in the run output.
        skipped_c, selected_cs = c_list[0], c_list[1:]
        print(f"[INFO] K={K}: skipping smallest c={skipped_c}; evaluating c values: {selected_cs}")

        results_for_c = {}
        for c in selected_cs:
            baseline, curves, n = average_over_seeds(K, c)
            if n == 0:
                print(f"[WARN] No seeds found for K={K}, c={c} — skipping.")
                continue

            results_for_c[c] = {
                "baseline": baseline,
                "magnitude": curves["magnitude"],
                "structured": curves["structured"],
            }

            # accumulate CSV rows; the unpruned model is recorded as ratio 0.0
            z0, a0 = baseline
            rows.append({"K":K, "c":c, "ratio":0.0, "method":"original", "z":z0, "acc":a0})
            for method, series in curves.items():
                for r, z, a in series:
                    rows.append({"K":K, "c":c, "ratio":r, "method":method, "z":z, "acc":a})

        if results_for_c:
            plot_left_z_curves(K, results_for_c)
            plot_right_tradeoff(K, results_for_c)
        else:
            print(f"[WARN] No plottable data for K={K}; no figures generated.")

    if rows:
        pd.DataFrame(rows).to_csv(OUT_DIR / "prune_curve_summary.csv", index=False)
        print(f"Saved CSV summary → {OUT_DIR / 'prune_curve_summary.csv'}")

    print(f"Done. Split figures saved in {OUT_DIR} (2 per K).")

# Run the sweep only when executed as a script. Note that the module-level
# setup above (path checks, dataset loading, file indexing) runs on import too.
if __name__ == "__main__":
    main()