# -*- coding: utf-8 -*-
"""
Quantization attacks on HIGHEST-VALIDATION models, averaged across seeds.

Layout assumed:
  repo/
    src/...
    example/
      quantization_attack.py   <-- this file (run from anywhere)
      untrusted_tinyimagenet_results/   <-- all .pt files live here (recursively)

Outputs in:
  example/quant_attack_tinyimagenet_figs/
    K{K}_z_bars.pdf
    K{K}_signal_vs_acc.pdf
    quant_summary.csv
"""

import os, sys, re, copy
from pathlib import Path
from typing import Dict, Tuple, Optional
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pandas as pd

# ── Paths ────────────────────────────────────────────────────────────────────
# Directory containing this script. All inputs and outputs are resolved relative
# to it, so the current working directory does not matter.
EXAMPLE_DIR = Path(__file__).resolve().parent          # .../example
ROOT = EXAMPLE_DIR.parent                               # repo root
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))                      # allow `from src...` imports

# Input tree that scan_files() walks recursively for *.pt files.
RESULTS_ROOT = EXAMPLE_DIR / "untrusted_tinyimagenet_results"   # <-- your results live HERE
if not RESULTS_ROOT.exists():
    # Raised at import time, before any dataset loading or scanning happens.
    raise FileNotFoundError(f"Expected results dir not found: {RESULTS_ROOT}")

# Destination for the per-K PDFs and the CSV summary; created on import if missing.
OUT_DIR = EXAMPLE_DIR / "quant_attack_tinyimagenet_figs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ── Imports (now that sys.path is set) ───────────────────────────────────────
from src.ResNetTinyImageNet import ResNet18_TinyImageNet
from src.data_utils import get_tinyimagenet_transforms, get_tinyimagenet_dataset, get_tinyimagenet_dataloaders
from src.utils import get_cosine_similarity_model

# ── Config ───────────────────────────────────────────────────────────────────
# Seeds averaged over by average_over_seeds(); a seed with no matching files is skipped.
SEEDS  = [0, 1, 2]
# Values of the K and c filename fields swept by main() (two plots are written per K).
K_LIST = [2, 4, 8, 16, 32, 64, 128]
C_LIST = [0.025, 0.05, 0.075, 0.1]

# Used by watermark_z_score(): z = (cosine - DETECTION_MEAN) / DETECTION_STD.
# NOTE(review): presumably the mean/std of the cosine statistic for models that
# carry no watermark; the origin of 0.0088 is not shown in this file - confirm.
DETECTION_MEAN = 0.0
DETECTION_STD  = 0.0088

# Models and loaded tensors are placed on CUDA when it is available, else on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batch size handed to get_tinyimagenet_dataloaders() below.
EVAL_BATCH = 128

# ── Data ─────────────────────────────────────────────────────────────────────
# Built once at import time. Only the third loader (test_loader) is kept; the
# first two returned loaders are discarded (presumably train and validation
# loaders - confirm against src.data_utils).
transform_train, transform_test = get_tinyimagenet_transforms()
train_dataset, val_dataset, test_dataset = get_tinyimagenet_dataset(transform_train, transform_test)
_, _, test_loader = get_tinyimagenet_dataloaders(train_dataset, val_dataset, test_dataset, batch_size=EVAL_BATCH)

# ── Filename parsing & index (scan recursively under RESULTS_ROOT) ───────────
# Both patterns are applied by scan_files() with .match() to the bare file name
# (not the full path) and share the same hyperparameter suffix. scan_files()
# reads only the K, c and seed groups; lr, steps and bs are captured but unused.
# The numeric groups for lr and c accept digits and dots only.

# Model checkpoint files: "highest_validation_accuracy_model_K.._lr.._c.._steps.._bs.._seed...pt"
RE_BEST = re.compile(
    r"highest_validation_accuracy_model_"
    r"K(?P<K>\d+)_lr(?P<lr>[\d.]+)_c(?P<c>[\d.]+)_steps(?P<steps>\d+)_bs(?P<bs>\d+)_seed(?P<seed>\d+)\.pt$"
)
# Flip-vector (watermark key) files: "tinyimagenet_flip_vectors_K.._lr.._c.._steps.._bs.._seed...pt"
RE_KEY = re.compile(
    r"tinyimagenet_flip_vectors_"
    r"K(?P<K>\d+)_lr(?P<lr>[\d.]+)_c(?P<c>[\d.]+)_steps(?P<steps>\d+)_bs(?P<bs>\d+)_seed(?P<seed>\d+)\.pt$"
)

def _round_c(x: float, nd: int = 6) -> float:
    return float(f"{x:.{nd}f}")

def scan_files() -> Tuple[Dict[Tuple[int,int,float], str], Dict[Tuple[int,int,float], str]]:
    """Index every recognised ``*.pt`` file under RESULTS_ROOT (recursively).

    Returns:
        ``(best_models, flip_keys)``: two dicts mapping ``(K, seed, c)`` to a
        file path string, where ``c`` has been normalised with ``_round_c``.
        The first dict holds files whose name matches RE_BEST, the second those
        matching RE_KEY. When several files map to the same key, the one with
        the most recent modification time is kept.
    """
    best_models: Dict[Tuple[int, int, float], str] = {}
    flip_keys: Dict[Tuple[int, int, float], str] = {}
    # Patterns are tried in this order; the first one that matches claims the file.
    targets = ((RE_BEST, best_models), (RE_KEY, flip_keys))

    for pt_file in RESULTS_ROOT.rglob("*.pt"):
        path_str = str(pt_file)
        for pattern, index in targets:
            hit = pattern.match(pt_file.name)
            if hit is None:
                continue
            key = (int(hit["K"]), int(hit["seed"]), _round_c(float(hit["c"])))
            current = index.get(key)
            # Prefer the newest file when duplicates exist for one key.
            if current is None or os.path.getmtime(path_str) > os.path.getmtime(current):
                index[key] = path_str
            break
    return best_models, flip_keys

# File indexes built once at import time and consulted by find_paths().
# Both map (K, seed, rounded c) -> path string: BEST_INDEX for model
# checkpoints, KEY_INDEX for flip-vector files.
BEST_INDEX, KEY_INDEX = scan_files()

def find_paths(K: int, seed: int, c_target: float) -> Tuple[Optional[str], Optional[str]]:
    """Locate the (model checkpoint, flip-vector file) pair for one run.

    An exact match on the rounded ``c`` is tried first. Otherwise the closest
    ``c`` present in BEST_INDEX for this ``(K, seed)`` is used, provided it is
    within 1e-4 of the target and both files exist for it.

    Returns:
        ``(model_path, key_path)``, or ``(None, None)`` when no complete pair
        is found.
    """
    def _pair(key):
        # A key is only usable when both the checkpoint and the key file exist.
        if key in BEST_INDEX and key in KEY_INDEX:
            return BEST_INDEX[key], KEY_INDEX[key]
        return None

    wanted = _round_c(c_target)
    found = _pair((K, seed, wanted))
    if found is not None:
        return found

    available = sorted({c for (k, s, c) in BEST_INDEX if k == K and s == seed})
    if not available:
        return None, None

    closest = min(available, key=lambda c: abs(c - wanted))
    if not abs(closest - wanted) <= 1e-4:
        return None, None

    found = _pair((K, seed, closest))
    if found is None:
        return None, None
    return found

# ── Quantization helpers ─────────────────────────────────────────────────────
def _quant_per_tensor_symmetric(w: torch.Tensor, bits: int) -> torch.Tensor:
    qmax = 2 ** (bits - 1) - 1
    max_abs = w.abs().max()
    if max_abs == 0:
        return w.clone()
    scale = max_abs / qmax
    qw = torch.clamp(torch.round(w / scale), -qmax, qmax)
    return qw * scale

def _quant_per_channel_out_symmetric(w: torch.Tensor, bits: int) -> torch.Tensor:
    if w.dim() not in (2, 4):
        return _quant_per_tensor_symmetric(w, bits)
    qmax = 2 ** (bits - 1) - 1
    w_flat = w.contiguous().view(w.shape[0], -1)
    max_abs = w_flat.abs().max(dim=1).values
    max_abs[max_abs == 0] = 1.0
    scales = max_abs / qmax
    s = scales.view(-1, 1) if w.dim() == 2 else scales.view(-1, 1, 1, 1)
    qw = torch.round(w / s).clamp_(-qmax, qmax)
    return qw * s

def quantize_model_weights(model: nn.Module, mode: str) -> nn.Module:
    """
    Return a copy of ``model`` whose weight tensors are snapped to a quantization grid.

    This is simulated quantization: values are rounded to an integer grid and
    rescaled, then written back into the existing floating-point parameters.

    mode ∈ {"static8", "static4", "dynamic8"}:
      - static8/static4 : per-tensor symmetric int8/int4
      - dynamic8        : per-output-channel (dim 0) symmetric int8

    Parameters are selected by shape, not by layer type: every trainable
    parameter with 2 or 4 dimensions is quantized. 1-D parameters (e.g. biases
    and BatchNorm affine terms) and parameters with ``requires_grad=False``
    are left unchanged.

    The input model is not modified. The copy is quantized on CPU and then
    moved to DEVICE before being returned.

    Raises:
        ValueError: if ``mode`` is not one of the supported names. The check is
            done up front, before the model is copied, so a typo is reported
            even for a model that has no quantizable parameters.
    """
    # mode -> (quantizer, bit width)
    schemes = {
        "static8": (_quant_per_tensor_symmetric, 8),
        "static4": (_quant_per_tensor_symmetric, 4),
        "dynamic8": (_quant_per_channel_out_symmetric, 8),
    }
    if mode not in schemes:
        raise ValueError(f"Unknown mode: {mode}")
    quantizer, bits = schemes[mode]

    m = copy.deepcopy(model).cpu()
    with torch.no_grad():
        for _, p in m.named_parameters():
            if not p.requires_grad or p.dim() not in (2, 4):
                continue
            p.data.copy_(quantizer(p.data, bits))
    return m.to(DEVICE)

# ── Eval helpers ─────────────────────────────────────────────────────────────
@torch.no_grad()
def evaluate_top1(model: nn.Module, loader) -> float:
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    The model is switched to eval mode (and left there). Each batch is moved
    to DEVICE, and a sample counts as correct when the argmax over dim 1 of the
    model output equals its label.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    n_hit = 0
    n_seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        predictions = model(inputs).argmax(1)
        n_hit += (predictions == labels).sum().item()
        n_seen += labels.size(0)
    return 100.0 * n_hit / n_seen

def watermark_z_score(model: nn.Module, flip_vectors: Dict[str, torch.Tensor]) -> float:
    """Standardise the model/key cosine similarity into a z-score.

    Computes ``(cos - DETECTION_MEAN) / DETECTION_STD`` where ``cos`` is the
    value returned by ``get_cosine_similarity_model(model, flip_vectors)``.
    """
    similarity = get_cosine_similarity_model(model, flip_vectors)
    standardized = (similarity - DETECTION_MEAN) / DETECTION_STD
    return float(standardized)

# ── Aggregate across seeds ───────────────────────────────────────────────────
def average_over_seeds(K: int, c: float) -> dict:
    """Evaluate one ``(K, c)`` setting under each quantization mode, averaged over SEEDS.

    For every seed with both a checkpoint and a flip-vector file, the original
    model and its "static8", "static4" and "dynamic8" quantized copies are
    scored with ``watermark_z_score`` and ``evaluate_top1`` on the test loader.
    Seeds with missing files are skipped with a warning.

    Returns:
        A dict with ``n_seeds`` (number of seeds actually evaluated) and the
        per-variant means ``z_*`` / ``acc_*`` for ``orig``, ``s8``, ``s4`` and
        ``d8``. Means are NaN when no seed was evaluated.
    """
    # (result tag, quantization mode); None means the unmodified model.
    variants = (
        ("orig", None),
        ("s8", "static8"),
        ("s4", "static4"),
        ("d8", "dynamic8"),
    )
    z_scores = {tag: [] for tag, _ in variants}
    accuracies = {tag: [] for tag, _ in variants}

    for seed in SEEDS:
        model_path, key_path = find_paths(K, seed, c)
        if model_path is None or key_path is None:
            print(f"[WARN] Missing best model or key for K={K}, c={c}, seed={seed}")
            continue

        model = ResNet18_TinyImageNet().to(DEVICE)
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed PyTorch version/defaults; only load trusted result files.
        model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        flip_vectors = torch.load(key_path, map_location=DEVICE)

        # The original model is scored first, then each quantized copy in turn.
        for tag, mode in variants:
            candidate = model if mode is None else quantize_model_weights(model, mode)
            z_scores[tag].append(watermark_z_score(candidate, flip_vectors))
            accuracies[tag].append(evaluate_top1(candidate, test_loader))

    def _mean_or_nan(values):
        # np.mean of an empty list warns and returns NaN; be explicit instead.
        if not values:
            return float("nan")
        return float(np.mean(values))

    return {
        "n_seeds": len(z_scores["orig"]),
        "z_orig": _mean_or_nan(z_scores["orig"]),
        "acc_orig": _mean_or_nan(accuracies["orig"]),
        "z_s8": _mean_or_nan(z_scores["s8"]),
        "acc_s8": _mean_or_nan(accuracies["s8"]),
        "z_s4": _mean_or_nan(z_scores["s4"]),
        "acc_s4": _mean_or_nan(accuracies["s4"]),
        "z_d8": _mean_or_nan(z_scores["d8"]),
        "acc_d8": _mean_or_nan(accuracies["d8"]),
    }

# ── Plotting ─────────────────────────────────────────────────────────────────
def _light_grid(ax):
    """Enable a faint (alpha 0.25) major+minor grid on ``ax`` and set axis-below mode."""
    ax.set_axisbelow(True)
    ax.grid(True, alpha=0.25, which="both")

def plot_for_K(K: int, mat: dict):
    """Write the two summary figures for one value of ``K``.

    Args:
        K: value used in the figure titles and output file names.
        mat: maps each ``c`` to a result dict holding ``z_*`` and ``acc_*``
            entries for the ``orig``, ``s8``, ``s4`` and ``d8`` variants.
            Nothing is written when it is empty.

    Outputs (under OUT_DIR):
        ``K{K}_z_bars.pdf``        grouped bars of z-score per ``c`` and variant,
                                   with a dashed line at Z=4.
        ``K{K}_signal_vs_acc.pdf`` scatter of test accuracy against z-score,
                                   each point annotated with its ``c``.
    """
    c_vals = sorted(mat.keys())
    if not c_vals:
        return

    # (legend label, result-key suffix, scatter marker), in plotting order.
    variants = (
        ("Original", "orig", "s"),
        ("Static 8-bit", "s8", "o"),
        ("Static 4-bit", "s4", "^"),
        ("Dynamic 8-bit", "d8", "d"),
    )
    z_by = {sfx: [mat[c]["z_" + sfx] for c in c_vals] for _, sfx, _ in variants}
    acc_by = {sfx: [mat[c]["acc_" + sfx] for c in c_vals] for _, sfx, _ in variants}

    # 1) Z-score bars: one group per c, one bar per variant.
    positions = np.arange(len(c_vals))
    bar_w = 0.2
    fig_bars, ax_bars = plt.subplots(figsize=(12, 5))
    for slot, (label, sfx, _) in enumerate(variants):
        ax_bars.bar(positions + slot * bar_w, z_by[sfx], width=bar_w, label=label)
    ax_bars.axhline(4.0, ls="--", lw=1.0, color="r", label="Detection Threshold (Z=4)")
    # Centre each tick under its group of four bars.
    ax_bars.set_xticks(positions + 1.5 * bar_w)
    tick_labels = [f"{c:.3f}".rstrip('0').rstrip('.') for c in c_vals]
    ax_bars.set_xticklabels(tick_labels)
    ax_bars.set_xlabel("Watermark Strength (c)")
    ax_bars.set_ylabel("Z Score")
    ax_bars.set_title(f"Watermark Detection Z-Scores After Quantization (K={K})")
    _light_grid(ax_bars)
    ax_bars.legend()
    fig_bars.tight_layout()
    fig_bars.savefig(OUT_DIR / f"K{K}_z_bars.pdf", bbox_inches='tight')
    plt.close(fig_bars)

    # 2) Z vs Accuracy scatter, every point labelled with its c.
    fig_sc, ax_sc = plt.subplots(figsize=(12, 5))
    for label, sfx, marker in variants:
        zs, accs = z_by[sfx], acc_by[sfx]
        ax_sc.scatter(zs, accs, label=label, marker=marker)
        for z_val, acc_val, c in zip(zs, accs, c_vals):
            ax_sc.annotate(f"c={c:.3f}".rstrip('0').rstrip('.'), (z_val, acc_val), fontsize=8)
    ax_sc.set_xlabel("Watermark Signal Strength (Z-Score)")
    ax_sc.set_ylabel("Test Accuracy (%)")
    ax_sc.set_title(f"Effect of Quantization on Watermark Signal and Model Accuracy (K={K})")
    _light_grid(ax_sc)
    ax_sc.legend()
    fig_sc.tight_layout()
    fig_sc.savefig(OUT_DIR / f"K{K}_signal_vs_acc.pdf", bbox_inches='tight')
    plt.close(fig_sc)

# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    """Run the full sweep over K_LIST x C_LIST, then write plots and the CSV summary.

    For each K, every c with at least one evaluated seed contributes a row to
    the summary and a point to that K's plots. K values with no data are
    reported and skipped. The CSV is written only when at least one row exists.
    """
    print(f"Device: {DEVICE}")
    rows = []

    for K in K_LIST:
        results_by_c = {}
        for c in C_LIST:
            stats = average_over_seeds(K, c)
            if stats["n_seeds"] == 0:
                print(f"[WARN] No valid seeds for K={K}, c={c}. Skipping.")
                continue
            results_by_c[c] = stats
            rows.append({"K": K, "c": c, **stats})

        if not results_by_c:
            print(f"[WARN] No data for K={K}; no plots generated.")
            continue
        plot_for_K(K, results_by_c)

    if rows:
        csv_path = OUT_DIR / "quant_summary.csv"
        pd.DataFrame(rows).to_csv(csv_path, index=False)
        print(f"Saved CSV summary → {csv_path}")
    print(f"Done. Plots saved in {OUT_DIR} (2 per K).")

# Script entry point: run the sweep only when executed directly. Importing this
# module still executes the module-level path checks, dataset loading and file
# scan defined above.
if __name__ == "__main__":
    main()