# distill_grid_from_exports_with_ci.py
# Distillation attack over grids of K, C, and seeds using SAVED (pickled) datasets only.
# Produces per-combo plots/CSVs and an aggregate mean ± 95% CI across seeds.

import csv
import glob
import os
import pickle
import re
import sys
from collections import defaultdict
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

# ── Paths ────────────────────────────────────────────────────────────────────
# Put the repository root at the front of sys.path so the `from src...` imports
# below resolve no matter which working directory the script is launched from
# (Path(__file__).resolve() is absolute). This must run before those imports.
EXAMPLE_DIR = Path(__file__).resolve().parent          # .../example
ROOT = EXAMPLE_DIR.parent                               # repo root
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))                      # allow `from src...` imports

from src.ResNet import ResNet18_CIFAR100
from src.utils import set_seed, get_device, get_cosine_similarity_model


# ===========================
# CONFIG (edit these)
# ===========================
# Where your teacher checkpoints + flip vectors live (the code will search these)
# Searched in list order by find_teacher_path / find_flip_vectors_path; entries
# that are not existing directories are skipped.
# NOTE(review): the "unntrusted_..." spelling looks like a deliberate fallback for
# a misspelled output directory — confirm before removing it.
RESULTS_DIR_CANDIDATES = ["untrusted_cifar100_results", "unntrusted_cifar100_results", "."]

# Where your exported pickles live (created by your training script)
# (The code will look for a dir matching the seed, e.g., "...seed0", "...seed1", etc.)
# Each pattern is filled with str.format(seed=...); the first one that exists wins.
EXPORT_DIR_PATTERNS = [
    "untrusted_cifar100_exported_data_cifar_100_seed{seed}",
    "unntrusted_cifar100_exported_data_cifar_100_seed{seed}",
]

# Grid to sweep
# K and c are interpolated into checkpoint filename globs ("..._K{K}_..._c{c}_..."),
# so each value must print exactly as it is spelled in the filenames
# (e.g. 0.025 -> "c0.025").
K_LIST    = [32]
C_LIST    = [0.025,0.05,0.075, 0.1]
SEED_LIST = [0, 1, 2]   # teacher/export seeds you trained with

# KD hyperparameters
PERCENTAGES = [1, 5, 10, 20]  # % of SAVED train dataset used for student training
# Number of KD epochs; the StepLR schedule multiplies the LR by 0.1 every
# max(1, EPOCHS // 2) epochs.
EPOCHS      = 100
STUDENT_BS  = 128             # student training batch size
TEST_BS     = 128             # eval batch size on SAVED test dataset
LR          = 1e-3            # Adam learning rate for the student
# Softmax temperature T applied to both student and teacher logits in the KD loss.
TEMPERATURE = 3.0
# Weight of the soft (KL) term in the KD loss; the hard-label CE term gets 1 - ALPHA.
ALPHA       = 0.5

# Distillation subset/initialization seed (fixed for fairness across combos)
# Passed to set_seed() once at the start of main() and used as the subset-selection
# seed for every KD run.
GLOBAL_ATTACK_SEED = 1

# Watermark z-score calibration
# z = (cosine_similarity - WMARK_MEAN) / max(WMARK_STD, EPS).
# NOTE(review): presumably the mean/std of the cosine similarity for models without
# the watermark — confirm where these numbers were calibrated.
WMARK_MEAN  = 0.0
WMARK_STD   = 0.0088
# A final student z <= Z_SIG is reported as watermark "removed".
Z_SIG       = 4.0             # significance threshold for "detected"
# Floor on WMARK_STD so the z-score never divides by zero.
EPS         = 1e-6

# Where to write outputs
# Also receives per-(K, c) aggregate folders and the master / aggregate CSVs.
OUT_BASE = "distill_grid_results_cifar_100"   # will contain subfolders per (K, c, seed)


# ===========================
# LOSS / HELPERS
# ===========================
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss mixing softened teacher targets with hard labels.

    total = alpha * T**2 * KL(p_teacher || p_student) + (1 - alpha) * CE(student_logits, labels)

    where p_x = softmax(logits_x / T) over dim 1. The KL term is computed with
    ``F.kl_div(..., reduction="batchmean")`` (summed over classes, averaged over the
    batch). It is scaled by T**2, the standard correction from Hinton et al. (2015)
    that keeps soft-target gradients on a comparable scale across temperatures.

    Args:
        temperature: softening temperature T applied to both logit tensors.
        alpha: weight of the soft (KL) term; the hard CE term gets ``1 - alpha``.

    ``forward(student_logits, teacher_logits, labels)`` returns
    ``(total, soft_loss, hard_loss)``. ``soft_loss`` already includes the T**2
    factor but not alpha. ``labels`` are passed unchanged to ``nn.CrossEntropyLoss``.
    """
    def __init__(self, temperature=3.0, alpha=0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        T = self.temperature
        log_p_s = F.log_softmax(student_logits / T, dim=1)
        p_t    = F.softmax(teacher_logits / T, dim=1)
        # kl_div(input=log q, target=p) computes KL(p || q): here KL(teacher || student).
        soft_loss = F.kl_div(log_p_s, p_t, reduction="batchmean") * (T ** 2)
        hard_loss = self.ce(student_logits, labels)
        total = self.alpha * soft_loss + (1.0 - self.alpha) * hard_loss
        return total, soft_loss, hard_loss


@torch.no_grad()
def evaluate_model(model, loader, device):
    """Return top-1 accuracy (in percent) of ``model`` over every batch of ``loader``.

    The model is switched to eval mode and left there. Batches are ``(inputs, targets)``
    pairs moved to ``device``; a prediction is the argmax over dim 1 of the output.
    An empty loader yields 0.0 instead of dividing by zero.
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        predictions = model(inputs).argmax(dim=1)
        n_seen += targets.size(0)
        n_hit += (predictions == targets).sum().item()
    denominator = max(1, n_seen)
    return 100.0 * n_hit / denominator


def create_subset_loader_from_dataset(dataset, percent, batch_size, seed=42, num_workers=2):
    """Build a shuffled DataLoader over a reproducible random ``percent``% subset of ``dataset``.

    The subset indices come from ``torch.randperm`` driven by a dedicated generator
    seeded with ``seed``, so the same (dataset size, percent, seed) always selects
    the same samples. At least one sample is always taken, and never more than
    ``len(dataset)``.

    Args:
        dataset: any map-style dataset supporting ``len``.
        percent: share of the dataset to keep, in percent (may be fractional).
        batch_size: DataLoader batch size.
        seed: seed for the subset-selection generator.
        num_workers: DataLoader worker processes.

    Returns:
        ``(loader, k)`` where ``k`` is the actual number of samples in the subset.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    total = len(dataset)
    if total == 0:
        raise ValueError("Cannot draw a subset from an empty dataset")
    # Multiply before dividing: with integer percentages this is exact, whereas
    # total * (percent / 100.0) can land just below an integer
    # (100 * 0.29 == 28.999...) and int() would floor it. Cap at the dataset size
    # so the returned count matches the subset actually built.
    k = min(total, max(1, int(total * percent / 100.0)))
    g = torch.Generator().manual_seed(seed)
    idx = torch.randperm(total, generator=g)[:k]
    subset = torch.utils.data.Subset(dataset, idx)
    loader = torch.utils.data.DataLoader(
        subset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    return loader, k


def load_exported_datasets(export_dir, override_test_bs=None):
    """Load the train/validation/test datasets pickled by the training script.

    ``export_dir`` must contain ``train_loader_data.pkl``,
    ``validation_loader_data.pkl`` and ``test_loader_data.pkl``. Each unpickles to a
    dict whose ``"dataset"`` entry is the dataset object; the test dict may also
    carry ``"batch_size"`` (128 when absent).

    Returns:
        ``(train_ds, val_ds, test_ds, test_loader)``. ``test_loader`` is an
        unshuffled DataLoader over ``test_ds`` using ``override_test_bs`` when given,
        otherwise the saved test batch size.

    Raises:
        FileNotFoundError: naming the first missing pickle, checked in
        train, validation, test order.

    NOTE(review): ``pickle.load`` can execute arbitrary code — only point this at
    exports you produced yourself.
    """
    def read_payload(filename):
        full_path = os.path.join(export_dir, filename)
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"Missing exported file: {full_path}")
        with open(full_path, "rb") as handle:
            return pickle.load(handle)

    train_payload, val_payload, test_payload = (
        read_payload(filename)
        for filename in ("train_loader_data.pkl", "validation_loader_data.pkl", "test_loader_data.pkl")
    )

    train_ds, val_ds, test_ds = (
        payload["dataset"] for payload in (train_payload, val_payload, test_payload)
    )

    saved_batch_size = test_payload.get("batch_size", 128)
    if override_test_bs is None:
        batch_size = saved_batch_size
    else:
        batch_size = override_test_bs

    test_loader = torch.utils.data.DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    return train_ds, val_ds, test_ds, test_loader


def newest(paths):
    """Return the most recently modified path in ``paths``, or None when it is empty/None.

    Paths sharing the latest mtime resolve to the earliest one in ``paths``.
    """
    if not paths:
        return None
    latest = max(paths, key=os.path.getmtime)
    return latest


def find_existing_dir(candidates):
    """Return the first entry of ``candidates`` that is an existing directory, else None."""
    return next((path for path in candidates if os.path.isdir(path)), None)


def find_export_dir(seed):
    """Return the first existing export directory for ``seed``, or None.

    Every template in ``EXPORT_DIR_PATTERNS`` is filled with ``seed`` and the
    resulting names are tried in order.
    """
    candidate_dirs = [template.format(seed=seed) for template in EXPORT_DIR_PATTERNS]
    return find_existing_dir(candidate_dirs)


def find_teacher_path(K, c, seed, search_dirs=None):
    """Locate the teacher checkpoint for one (K, c, seed) combination.

    Directories are tried in order: ``search_dirs`` if given, otherwise
    ``RESULTS_DIR_CANDIDATES``. Entries that are not existing directories are
    skipped. Within each directory the checkpoint kinds are tried in priority order:

      1. ``highest_validation_accuracy_model_*``
      2. ``highest_accuracy_model_*``
      3. ``cifar100_fedavg_watermark_*`` (final model)

    The first (directory, pattern) pair with any match wins, and the most recently
    modified of its matches is returned. A lower-priority kind found in an earlier
    directory therefore beats a higher-priority kind in a later directory.

    ``lr``, ``steps`` and ``bs`` are wildcards. ``K``, ``c`` and ``seed`` are
    formatted as in an f-string (e.g. c=0.025 matches ``c0.025``).

    Returns:
        The checkpoint path, or None if nothing matches.
    """
    patterns_priority = [
        f"highest_validation_accuracy_model_K{K}_lr*_c{c}_steps*_bs*_seed{seed}.pt",
        f"highest_accuracy_model_K{K}_lr*_c{c}_steps*_bs*_seed{seed}.pt",
        f"cifar100_fedavg_watermark_K{K}_lr*_c{c}_steps*_bs*_seed{seed}.pt",
    ]
    dirs = RESULTS_DIR_CANDIDATES if search_dirs is None else search_dirs
    for base in dirs:
        if not os.path.isdir(base):
            continue
        # Escape the directory so glob metacharacters in its name ("[", "*", "?")
        # are matched literally; only the filename pattern should be a wildcard.
        escaped_base = glob.escape(base)
        for pat in patterns_priority:
            paths = glob.glob(os.path.join(escaped_base, pat))
            if paths:
                return newest(paths)
    return None


def find_flip_vectors_path(K, c, seed, teacher_path=None):
    """Locate the flip-vector file belonging to a teacher checkpoint.

    1. If ``teacher_path`` is given and its basename ends in ``model_<suffix>.pt``
       (e.g. ``highest_accuracy_model_K16_lr0.001_c0.025_steps200_bs128_seed0.pt``),
       return ``cifar100_flip_vectors_<suffix>.pt`` from the same directory if it exists.
    2. Otherwise glob
       ``cifar100_flip_vectors_K{K}_lr*_c{c}_steps*_bs*_seed{seed}.pt``. The search
       runs in the teacher's directory (a bare filename counts as the current
       directory), or in each ``RESULTS_DIR_CANDIDATES`` entry when no
       ``teacher_path`` is given. The newest match from the first directory with
       any match is returned.

    Returns:
        The flip-vector path, or None if nothing matches.
    """
    if teacher_path:
        teacher_dir = os.path.dirname(teacher_path)
        m = re.search(r"model_(K.+)\.pt$", os.path.basename(teacher_path))
        if m:
            suffix = m.group(1)  # e.g., K16_lr0.001_c0.025_steps200_bs128_seed0
            cand = os.path.join(teacher_dir, f"cifar100_flip_vectors_{suffix}.pt")
            if os.path.exists(cand):
                return cand
        # dirname() of a bare filename is "", which os.path.isdir rejects; the file
        # lives in the current directory, so search there instead.
        dirs = [teacher_dir or "."]
    else:
        dirs = RESULTS_DIR_CANDIDATES

    patterns = [f"cifar100_flip_vectors_K{K}_lr*_c{c}_steps*_bs*_seed{seed}.pt"]
    for base in dirs:
        if not os.path.isdir(base):
            continue
        # Escape the directory so only the filename pattern acts as a wildcard.
        escaped_base = glob.escape(base)
        for pat in patterns:
            paths = glob.glob(os.path.join(escaped_base, pat))
            if paths:
                return newest(paths)
    return None


def run_single_percentage(
    teacher_model,
    flip_vectors,
    train_dataset,
    test_loader,
    percent,
    *,
    device,
    epochs=50,
    lr=1e-3,
    student_batch_size=128,
    temperature=3.0,
    alpha=0.7,
    seed=1,
    wmark_mean=0.0,
    wmark_std=0.0088
):
    """Distil ``teacher_model`` into a fresh ResNet18_CIFAR100 on ``percent``% of ``train_dataset``.

    ``seed`` selects the training subset. It also reseeds the global torch RNG before
    the student is built, so the student's initial weights and the DataLoader's
    shuffle order depend only on ``seed``, not on RNG state left by earlier runs.

    Training uses Adam(lr) with a StepLR schedule: the LR is multiplied by 0.1 every
    ``max(1, epochs // 2)`` epochs. The loss is ``DistillationLoss(temperature, alpha)``.

    After epoch 0 (the untrained student), every 10 epochs, and after the last
    epoch, test accuracy and watermark z-score are recorded. The z-score is
    ``(cos - wmark_mean) / max(wmark_std, EPS)``, where ``cos`` comes from
    ``get_cosine_similarity_model(model, flip_vectors)``.

    Returns:
        dict with ``percent``, ``subset_size``, ``epoch_hist``, ``acc_hist``,
        ``z_hist``, ``final_acc``, ``final_z``, ``teacher_acc``, ``teacher_z``,
        ``transfer`` (student / teacher accuracy, in percent) and ``removed``
        (final z <= Z_SIG).
    """
    def to_z(cos):
        # EPS guards against a zero calibration std.
        return (cos - wmark_mean) / max(wmark_std, EPS)

    print("\n" + "=" * 72)
    print(f"DISTILLATION @ {percent:.1f}% of SAVED train dataset")
    print("=" * 72)

    # Subset (drawn with its own generator seeded by `seed`)
    train_loader, subset_size = create_subset_loader_from_dataset(
        train_dataset, percent, student_batch_size, seed=seed
    )
    print(f"Using {subset_size:,} / {len(train_dataset):,} images")

    # Student. Reseed first so every run with the same `seed` starts from the same
    # weights and shuffle stream, independent of how much RNG earlier runs consumed.
    torch.manual_seed(seed)
    student = ResNet18_CIFAR100().to(device)
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=max(1, epochs // 2), gamma=0.1)
    crit = DistillationLoss(temperature=temperature, alpha=alpha)

    # Baselines (the untrained student is recorded as epoch 0)
    with torch.no_grad():
        teacher_acc = evaluate_model(teacher_model, test_loader, device)
        t_cos = get_cosine_similarity_model(teacher_model, flip_vectors)
        s_cos = get_cosine_similarity_model(student, flip_vectors)
    t_z = to_z(t_cos)
    s_z = to_z(s_cos)
    s_acc = evaluate_model(student, test_loader, device)

    print(f"Initial — Teacher: {teacher_acc:.2f}% acc | z={t_z:.2f}")
    print(f"Initial — Student: {s_acc:.2f}% acc | z={s_z:.2f}")

    epoch_hist = [0]
    acc_hist   = [s_acc]
    z_hist     = [s_z]

    for ep in tqdm(range(epochs), desc=f"KD {percent}%"):
        # evaluate_model leaves the student in eval mode, so re-enable training mode.
        student.train()
        teacher_model.eval()

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                t_logits = teacher_model(x)
            opt.zero_grad(set_to_none=True)
            s_logits = student(x)
            total_loss, _, _ = crit(s_logits, t_logits, y)
            total_loss.backward()
            opt.step()

        sched.step()

        # log every 10 epochs + last
        if ((ep + 1) % 10 == 0) or (ep == epochs - 1):
            s_acc = evaluate_model(student, test_loader, device)
            s_cos = get_cosine_similarity_model(student, flip_vectors)
            s_z   = to_z(s_cos)
            epoch_hist.append(ep + 1)
            acc_hist.append(s_acc)
            z_hist.append(s_z)
            print(f"  Epoch {ep+1:3d}: acc {s_acc:6.2f}% | z={s_z:6.2f} {'✓' if s_z>Z_SIG else '✗'}")

    return {
        "percent": percent,
        "subset_size": subset_size,
        "epoch_hist": epoch_hist,
        "acc_hist": acc_hist,
        "z_hist": z_hist,
        "final_acc": acc_hist[-1],
        "final_z": z_hist[-1],
        "teacher_acc": teacher_acc,
        "teacher_z": t_z,
        "transfer": (acc_hist[-1] / max(1e-6, teacher_acc)) * 100.0,
        # Watermark counts as removed once the final z no longer exceeds the threshold.
        "removed": z_hist[-1] <= Z_SIG,
    }


def plot_tradeoff(all_results, teacher_acc, teacher_z, outdir, cfg_str):
    """Plot test accuracy vs. watermark z-score trajectories for one (K, c, seed) combo.

    Args:
        all_results: dicts as returned by ``run_single_percentage``. The keys used
            are ``percent``, ``subset_size``, ``z_hist``, ``acc_hist`` and
            ``epoch_hist``. Each dict becomes one line whose start and end points
            are annotated with their epoch numbers.
        teacher_acc, teacher_z: teacher reference point, drawn as a star.
        outdir: directory receiving ``accuracy_vs_zscore.pdf`` and ``.png``
            (created if missing).
        cfg_str: run description printed under the plot.

    The figure is always closed afterwards. Without this, one open figure would
    accumulate per call across the grid.
    """
    os.makedirs(outdir, exist_ok=True)
    fig = plt.figure(figsize=(11, 8))
    try:
        for r in all_results:
            z = r["z_hist"]
            a = r["acc_hist"]
            # :g keeps integer percents as "5" but does not truncate fractional ones (0.5).
            label = f"{r['percent']:g}% data ({r['subset_size']:,})"
            plt.plot(z, a, marker='o', linewidth=2, markersize=6, label=label, alpha=0.85)
            # annotate start/end
            plt.annotate("E0", (z[0], a[0]), xytext=(5, 6), textcoords="offset points", fontsize=8, alpha=0.7)
            plt.annotate(f"E{r['epoch_hist'][-1]}", (z[-1], a[-1]), xytext=(5, 6), textcoords="offset points", fontsize=8, alpha=0.7)

        # Teacher point
        plt.scatter([teacher_z], [teacher_acc], s=180, marker='*', edgecolor='white', linewidth=2, label="Teacher", zorder=5)
        plt.annotate(f"Teacher\n({teacher_acc:.1f}%, z={teacher_z:.1f})",
                     (teacher_z, teacher_acc), xytext=(10, 10), textcoords="offset points",
                     fontsize=10, fontweight="bold")

        # Significance threshold
        plt.axvline(Z_SIG, linestyle="--", linewidth=1.8, color="black", alpha=0.8, label=f"z = {Z_SIG:g}")

        # Light square grid
        plt.grid(True, which="both", alpha=0.15)

        plt.xlabel("Watermark Z-score", fontsize=13)
        plt.ylabel("Test Accuracy (%)", fontsize=13)
        plt.title("Distillation Attack: Accuracy vs Watermark Detectability (Saved Datasets)", fontsize=15, pad=12)

        plt.legend(loc="lower right", framealpha=0.9, fontsize=10)
        plt.tight_layout()
        plt.figtext(0.5, 0.01, cfg_str, ha="center", fontsize=9, style="italic", alpha=0.75)

        path_pdf = os.path.join(outdir, "accuracy_vs_zscore.pdf")
        path_png = os.path.join(outdir, "accuracy_vs_zscore.png")
        plt.savefig(path_pdf, dpi=300, bbox_inches="tight")
        plt.savefig(path_png, dpi=300, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Saved: {path_pdf}")
    print(f"Saved: {path_png}")


def save_summary_csv(path, header, rows):
    """Write ``header`` followed by ``rows`` to ``path`` as CSV and report the path.

    Parent directories are created as needed; a bare filename is written to the
    current directory. Values are converted with ``str()`` by ``csv.writer``
    (``None`` becomes an empty field). Fields containing commas or quotes are quoted
    rather than corrupting the row. Lines end with ``"\\n"``.
    """
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises FileNotFoundError, so only create real directories.
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(header)
        writer.writerows(rows)
    print(f"Saved: {path}")


# ===== 95% CI aggregation (across seeds) =====
def tcrit_95(n):
    """Two-sided 95% Student-t critical value for a sample of size ``n`` (df = n - 1).

    For df 1-30, the standard table values (3 d.p.) are returned exactly. Beyond
    df 30 the value is linearly interpolated in 1/df between the usual table anchors
    df = 30, 40, 60, 120 and infinity. This "harmonic" interpolation is accurate to
    about 1e-3 and converges to the normal value 1.96. Samples with fewer than two
    observations have no degrees of freedom and get 1.96.
    """
    df = int(n) - 1
    if df < 1:
        return 1.96
    # t_{0.975, df} for df = 1..30.
    table = (
        12.706, 4.303, 3.182, 2.776, 2.571, 2.447, 2.365, 2.306, 2.262, 2.228,
        2.201, 2.179, 2.160, 2.145, 2.131, 2.120, 2.110, 2.101, 2.093, 2.086,
        2.080, 2.074, 2.069, 2.064, 2.060, 2.056, 2.052, 2.048, 2.045, 2.042,
    )
    if df <= len(table):
        return table[df - 1]
    anchors = ((30, 2.042), (40, 2.021), (60, 2.000), (120, 1.980), (float("inf"), 1.960))
    inv_df = 1.0 / df
    for (df_lo, t_lo), (df_hi, t_hi) in zip(anchors, anchors[1:]):
        if df <= df_hi:
            x_lo, x_hi = 1.0 / df_lo, 1.0 / df_hi  # 1/inf == 0.0
            return t_lo + (t_hi - t_lo) * (x_lo - inv_df) / (x_lo - x_hi)
    return 1.96  # unreachable: the last anchor is df = inf

def mean_ci_95(xs):
    """Return ``(mean, half_width)`` of a two-sided 95% t-interval for the values in ``xs``.

    The half-width is ``tcrit_95(n) * sample_std(ddof=1) / sqrt(n)``. With no
    values the mean is NaN, and with fewer than two values the half-width is 0.0.
    """
    values = np.asarray(xs, dtype=float)
    count = len(values)
    if count == 0:
        return float('nan'), 0.0
    center = float(values.mean())
    if count == 1:
        return center, 0.0
    spread = float(values.std(ddof=1))
    half_width = tcrit_95(count) * spread / np.sqrt(count)
    return center, half_width

def aggregate_across_seeds(master_rows, out_base):
    """Summarise per-seed KD results as mean ± 95% CI for every (K, c, percent).

    ``master_rows`` entries hold, in order (as numbers or their string forms):
    ``[K, c, seed, pct, subset_size, t_acc, t_z, final_acc, final_z, transfer, removed]``.
    Only K, c, pct, subset_size, final_acc and final_z are used.

    Side effects:
      * writes ``<out_base>/distill_master_agg_95ci.csv``, one line per (K, c, pct)
        in sorted order;
      * for each (K, c), writes ``acc_vs_pct_95ci.png`` and ``z_vs_pct_95ci.png``
        (the latter with the Z_SIG threshold line) under ``<out_base>/K{K}_c{c}/``.
    """
    grouped = defaultdict(lambda: {"acc": [], "z": [], "subset_sizes": []})
    for entry in master_rows:
        (K, c, _seed, pct, subset_size, _t_acc, _t_z,
         final_acc, final_z, _transfer, _removed) = entry
        bucket = grouped[(int(K), float(c), float(pct))]
        bucket["acc"].append(float(final_acc))
        bucket["z"].append(float(final_z))
        bucket["subset_sizes"].append(int(subset_size))

    # --- aggregated CSV ---
    summary = []
    os.makedirs(out_base, exist_ok=True)
    csv_path = os.path.join(out_base, "distill_master_agg_95ci.csv")
    with open(csv_path, "w") as fh:
        fh.write("K,c,percent,n,mean_acc,ci95_acc,mean_z,ci95_z,mean_subset\n")
        for K, c, pct in sorted(grouped):
            bucket = grouped[(K, c, pct)]
            n = len(bucket["acc"])
            mu_acc, ci_acc = mean_ci_95(bucket["acc"])
            mu_z, ci_z = mean_ci_95(bucket["z"])
            sizes = bucket["subset_sizes"]
            mean_subset = int(np.mean(sizes)) if sizes else 0
            fh.write(f"{K},{c},{pct},{n},{mu_acc:.6f},{ci_acc:.6f},{mu_z:.6f},{ci_z:.6f},{mean_subset}\n")
            summary.append((K, c, pct, n, mu_acc, ci_acc, mu_z, ci_z, mean_subset))
    print(f"Saved: {csv_path}")

    # --- per-(K, c) error-bar plots ---
    def save_errorbar_png(directory, xs, ys, errs, ylabel, title, filename, with_threshold):
        plt.figure(figsize=(7.2, 5.2))
        plt.errorbar(xs, ys, yerr=errs, fmt='-o', capsize=4, linewidth=2)
        if with_threshold:
            plt.axhline(Z_SIG, linestyle="--", color="black", alpha=0.8, linewidth=1.5, label=f"z={Z_SIG:g}")
        plt.grid(True, which="both", alpha=0.15)
        plt.xlabel("Train subset used for KD (%)")
        plt.ylabel(ylabel)
        plt.title(title)
        if with_threshold:
            plt.legend()
        png_path = os.path.join(directory, filename)
        plt.savefig(png_path, dpi=300, bbox_inches="tight")
        plt.close()
        print(f"Saved: {png_path}")

    per_combo = defaultdict(list)
    for K, c, pct, _n, mu_acc, ci_acc, mu_z, ci_z, _mean_subset in summary:
        per_combo[(K, c)].append((pct, mu_acc, ci_acc, mu_z, ci_z))

    for (K, c), points in per_combo.items():
        points.sort(key=lambda point: point[0])
        pcts, acc_mu, acc_ci, z_mu, z_ci = map(list, zip(*points))

        combo_dir = os.path.join(out_base, f"K{K}_c{c}")
        os.makedirs(combo_dir, exist_ok=True)

        save_errorbar_png(
            combo_dir, pcts, acc_mu, acc_ci,
            "Final Test Accuracy (%)",
            f"Accuracy (mean ±95% CI) — K={K}, c={c}",
            "acc_vs_pct_95ci.png",
            False,
        )
        save_errorbar_png(
            combo_dir, pcts, z_mu, z_ci,
            "Final Watermark Z-score",
            f"Z-score (mean ±95% CI) — K={K}, c={c}",
            "z_vs_pct_95ci.png",
            True,
        )


# ===========================
# MAIN
# ===========================
def _load_teacher(teacher_path, device):
    """Build a ResNet18_CIFAR100 on ``device``, load the checkpoint at ``teacher_path``, set eval mode."""
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    net = ResNet18_CIFAR100().to(device)
    net.load_state_dict(torch.load(teacher_path, map_location=device))
    net.eval()
    return net


def _load_flip_vectors(flip_path, device):
    """Load the flip-vector mapping at ``flip_path``; cast every tensor value to float32 on ``device``.

    Non-tensor values (e.g. None) are left untouched.
    """
    vectors = torch.load(flip_path, map_location=device)
    for name in list(vectors.keys()):
        value = vectors[name]
        if isinstance(value, torch.Tensor):
            vectors[name] = value.to(device=device, dtype=torch.float32)
    return vectors


def _distill_combo(K, c, seed, train_ds, test_loader, device):
    """Run the KD sweep over PERCENTAGES for one (K, c, seed) teacher.

    Writes the per-combo summary CSV and trade-off plot under
    ``OUT_BASE/K{K}_c{c}_seed{seed}``. Returns the rows for the master CSV, or an
    empty list (after a warning) when the teacher or its flip vectors are missing.
    """
    banner = "#" * 80
    print("\n" + banner)
    print(f"Grid combo: K={K}, c={c}, seed={seed}")
    print(banner)

    teacher_path = find_teacher_path(K, c, seed)
    if not teacher_path:
        print(f"[WARN] Teacher not found for K={K}, c={c}, seed={seed}. Skipping combo.")
        return []

    flip_path = find_flip_vectors_path(K, c, seed, teacher_path)
    if not flip_path:
        print(f"[WARN] Flip vectors not found for K={K}, c={c}, seed={seed}. Skipping combo.")
        return []

    teacher = _load_teacher(teacher_path, device)
    flip_vectors = _load_flip_vectors(flip_path, device)

    # Teacher reference point for the CSVs and the plot.
    with torch.no_grad():
        t_acc = evaluate_model(teacher, test_loader, device)
        t_cos = get_cosine_similarity_model(teacher, flip_vectors)
    t_z = (t_cos - WMARK_MEAN) / max(WMARK_STD, EPS)
    print(f"Teacher baseline (seed={seed}): {t_acc:.2f}% acc | z={t_z:.2f}")

    combo_out = os.path.join(OUT_BASE, f"K{K}_c{c}_seed{seed}")
    os.makedirs(combo_out, exist_ok=True)

    results = []
    combo_rows = []
    for pct in PERCENTAGES:
        outcome = run_single_percentage(
            teacher_model=teacher,
            flip_vectors=flip_vectors,
            train_dataset=train_ds,
            test_loader=test_loader,
            percent=pct,
            device=device,
            epochs=EPOCHS,
            lr=LR,
            student_batch_size=STUDENT_BS,
            temperature=TEMPERATURE,
            alpha=ALPHA,
            seed=GLOBAL_ATTACK_SEED,
            wmark_mean=WMARK_MEAN,
            wmark_std=WMARK_STD,
        )
        results.append(outcome)
        combo_rows.append([
            K, c, seed, pct, outcome["subset_size"],
            f"{t_acc:.4f}", f"{t_z:.6f}",
            f"{outcome['final_acc']:.4f}", f"{outcome['final_z']:.6f}",
            f"{outcome['transfer']:.4f}", int(outcome["removed"]),
        ])

    summary_rows = [
        [
            res["percent"], res["subset_size"],
            f"{res['final_acc']:.6f}", f"{res['final_z']:.6f}",
            f"{res['transfer']:.6f}", int(res["removed"]),
        ]
        for res in results
    ]
    save_summary_csv(
        os.path.join(combo_out, "distill_summary.csv"),
        header=["percent", "subset_size", "final_acc", "final_z", "transfer", "removed"],
        rows=summary_rows,
    )

    cfg = (f"K={K}, c={c}, seed={seed} | T={TEMPERATURE}, alpha={ALPHA}, "
           f"epochs={EPOCHS}, lr={LR}, bs(student)={STUDENT_BS}")
    plot_tradeoff(results, t_acc, t_z, combo_out, cfg)
    return combo_rows


def main():
    """Sweep every (seed, K, c) teacher and write per-combo, master and aggregated outputs.

    Seeds without an export directory or with missing pickles are skipped with a
    warning, as are combos whose teacher or flip vectors cannot be found. At the end
    a timestamped master CSV is written to OUT_BASE, followed by the mean ± 95% CI
    aggregation across seeds.
    """
    device = get_device()
    set_seed(GLOBAL_ATTACK_SEED)
    os.makedirs(OUT_BASE, exist_ok=True)

    master_header = [
        "K", "c", "seed", "percent", "subset_size",
        "teacher_acc", "teacher_z", "final_acc", "final_z",
        "transfer_percent", "removed",
    ]
    master_rows = []

    for seed in SEED_LIST:
        export_dir = find_export_dir(seed)
        if not export_dir:
            print(f"[WARN] No exported dataset dir found for seed={seed} in patterns {EXPORT_DIR_PATTERNS}. Skipping seed.")
            continue

        try:
            train_ds, _val_ds, _test_ds, test_loader = load_exported_datasets(export_dir, override_test_bs=TEST_BS)
        except FileNotFoundError as err:
            print(f"[WARN] {err} — Skipping seed={seed}.")
            continue

        for K in K_LIST:
            for c in C_LIST:
                master_rows.extend(_distill_combo(K, c, seed, train_ds, test_loader, device))

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_summary_csv(os.path.join(OUT_BASE, f"distill_master_{stamp}.csv"), master_header, master_rows)

    aggregate_across_seeds(master_rows, OUT_BASE)


if __name__ == "__main__":
    # Run the full grid only when executed as a script, not when imported.
    main()