#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Non-linear attack on mixup privacy for multiple taus.

For each tau in a given list:
  * compute dataset stats (r, v_hat, c), get mf via theory
  * train a lightweight U-Net on Tiny-ImageNet mixups (train on public)
  * evaluate zero-shot on CIFAR-10 mixups (attack on private), report SNR

Finally, save ONE figure with 3 rows x len(taus) columns:
  - Row 1: GT image corresponding to the best recovered example at the largest tau
  - Row 2: Mixup input for that same example index, generated under each tau
  - Row 3: Recovered image (U-Net output) under each tau
Column headers show tau as 10^k in LaTeX (e.g., 10^{-1}).

All console prints are also logged to a file via Tee.

Note: Uses ImageNet-style normalization and resizes inputs to 224x224.
"""

import os
import sys
import math
import argparse
import random
from datetime import datetime
from typing import Tuple, List, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

# ---------------------------
# Tee logging (prints go to stdout AND a file)
# ---------------------------

class Tee(object):
    """File-like object that duplicates every write to stdout and a log file.

    The ``sys.stdout`` / ``sys.stderr`` objects that are current at
    construction time are stored on the instance (``stdout`` / ``stderr``) so
    that a caller who redirects the process streams to this object can
    restore them later.

    Args:
        filepath: Path of the log file to open.
        mode: Text mode passed to ``open`` (``"w"`` truncates, ``"a"`` appends).
    """

    def __init__(self, filepath: str, mode: str = "w"):
        # Explicit UTF-8: the script prints non-ASCII characters (e.g. "±"),
        # which would raise UnicodeEncodeError with an ASCII locale default.
        self.file = open(filepath, mode, encoding="utf-8")
        self.stdout = sys.stdout
        self.stderr = sys.stderr

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def write(self, data):
        """Write ``data`` to the saved stdout and, while it is open, the log file."""
        written = self.stdout.write(data)
        # After close() keep echoing to the console instead of raising
        # "I/O operation on closed file".
        if not self.file.closed:
            self.file.write(data)
        return written

    def flush(self):
        """Flush the saved stdout and, while it is open, the log file."""
        self.stdout.flush()
        if not self.file.closed:
            self.file.flush()

    def close(self):
        """Close the log file; best effort and safe to call more than once."""
        if self.file.closed:
            return
        try:
            self.file.close()
        except OSError:
            # Closing is best effort: a failed final flush must not mask the
            # error (or normal exit) that led to close() being called.
            pass

# ---------------------------
# Normalization utilities
# ---------------------------

# Per-channel mean / std handed to transforms.Normalize in get_transforms and
# used again to undo that normalization when clamping or plotting images.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

def get_transforms(dataset_type: str):
    """Build the preprocessing pipeline for a named dataset.

    Every supported dataset gets ``Resize(224)`` -> ``ToTensor`` ->
    ``Normalize(IMAGENET_MEAN, IMAGENET_STD)``; for ``"mnist"`` a
    ``Grayscale(num_output_channels=3)`` step is inserted after the resize.

    Args:
        dataset_type: Dataset name, matched case-insensitively. One of
            ``"cifar10"``, ``"cifar100"``, ``"tiny-imagenet"``, ``"mnist"``.

    Returns:
        A ``transforms.Compose`` instance.

    Raises:
        ValueError: If the (lower-cased) name is not supported.
    """
    key = dataset_type.lower()
    colour_sets = ("cifar10", "cifar100", "tiny-imagenet")
    if key not in colour_sets and key != "mnist":
        raise ValueError(f"Unsupported dataset: {key}")

    steps = [transforms.Resize(224)]
    if key == "mnist":
        # Replicate the single channel so the 3-channel statistics apply.
        steps.append(transforms.Grayscale(num_output_channels=3))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return transforms.Compose(steps)

def get_dataset(name: str):
    """Return ``(train, test)`` torchvision datasets for ``name``.

    The transform comes from ``get_transforms(name)``. Tiny-ImageNet is read
    with ``ImageFolder`` from ``./data/tiny-imagenet-200`` (``train`` plus
    ``val`` if that directory exists, otherwise ``test``); CIFAR-10,
    CIFAR-100 and MNIST are created under ``./data`` with ``download=True``.

    Args:
        name: Dataset name, matched case-insensitively.

    Returns:
        Tuple ``(train_dataset, eval_dataset)``.

    Raises:
        FileNotFoundError: If the Tiny-ImageNet directories are missing.
        ValueError: If the dataset name is not supported.
    """
    key = name.lower()
    transform = get_transforms(key)

    if key == "tiny-imagenet":
        tiny_root = os.path.join("./data", "tiny-imagenet-200")
        train_dir = os.path.join(tiny_root, "train")
        val_dir = os.path.join(tiny_root, "val")
        # Prefer the validation split; fall back to the test split.
        eval_dir = val_dir if os.path.isdir(val_dir) else os.path.join(tiny_root, "test")
        if not os.path.isdir(train_dir) or not os.path.isdir(eval_dir):
            raise FileNotFoundError("Tiny-ImageNet expected in ./data/tiny-imagenet-200/{train,val or test}")
        train_split = torchvision.datasets.ImageFolder(train_dir, transform=transform)
        eval_split = torchvision.datasets.ImageFolder(eval_dir, transform=transform)
        return train_split, eval_split

    class_names = {"cifar10": "CIFAR10", "cifar100": "CIFAR100", "mnist": "MNIST"}
    if key not in class_names:
        raise ValueError(f"Unsupported dataset: {key}")
    dataset_cls = getattr(torchvision.datasets, class_names[key])
    train_split = dataset_cls(root="./data", train=True, download=True, transform=transform)
    eval_split = dataset_cls(root="./data", train=False, download=True, transform=transform)
    return train_split, eval_split

def clamp_imagenet_normalized(x: torch.Tensor) -> torch.Tensor:
    """Clamp a normalized image tensor to the valid pixel range.

    The tensor is de-normalized with ``IMAGENET_MEAN`` / ``IMAGENET_STD``,
    clamped to ``[0, 1]`` and normalized again. The statistics are reshaped
    to ``(1, C, 1, 1)``, so they broadcast against a tensor whose channel
    axis is dim 1.

    Args:
        x: Normalized image tensor; the statistics are created on its device.

    Returns:
        Tensor in normalized space whose de-normalized values lie in [0, 1].
    """
    mu, sigma = (
        torch.tensor(stats, device=x.device).view(1, -1, 1, 1)
        for stats in (IMAGENET_MEAN, IMAGENET_STD)
    )
    pixels = torch.clamp(x * sigma + mu, 0.0, 1.0)
    return (pixels - mu) / sigma

def chw_to_numpy_img01(x_chw: torch.Tensor) -> np.ndarray:
    """Convert one normalized 3-channel CHW image to an HWC array in [0, 1].

    Args:
        x_chw: Normalized image tensor with the 3 channels on dim 0.

    Returns:
        NumPy array with the channel axis moved last, de-normalized with
        ``IMAGENET_MEAN`` / ``IMAGENET_STD`` and clamped to ``[0, 1]``.
    """
    dev = x_chw.device
    mu = torch.tensor(IMAGENET_MEAN, device=dev).view(3, 1, 1)
    sigma = torch.tensor(IMAGENET_STD, device=dev).view(3, 1, 1)
    pixels = torch.clamp(x_chw * sigma + mu, 0, 1)
    hwc = pixels.permute(1, 2, 0)
    return hwc.detach().cpu().numpy()

# ---------------------------
# Subsetting & stats (r, v_hat, c)
# ---------------------------

def make_subdataset(ds, size: int, seed: int):
    """Return a seeded random ``Subset`` of ``ds`` with at most ``size`` items.

    Indices are drawn without replacement from a NumPy generator seeded with
    ``seed`` and stored in ascending order.

    Args:
        ds: Dataset supporting ``len``.
        size: Requested number of items; capped at ``len(ds)``.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        ``torch.utils.data.Subset`` over the chosen indices.
    """
    total = len(ds)
    keep = min(size, total)
    picked = np.random.default_rng(seed).choice(total, size=keep, replace=False)
    return Subset(ds, np.sort(picked).tolist())

@torch.no_grad()
def estimate_average_distance(dataset, device="cpu", max_images: int = 2048, batch_size: int = 64,
                              num_workers: int = 2) -> float:
    """Estimate the mean L2 distance between two distinct images of ``dataset``.

    Up to ``max_images`` images are loaded (shuffled), every image is paired
    with a *different* image through a random cyclic derangement, and the
    Euclidean distances between the flattened pairs are averaged.

    Args:
        dataset: Map-style dataset yielding ``(image_tensor, label)`` pairs.
        device: Device the images are moved to before computing distances.
        max_images: Upper bound on the number of images used (>= 1).
        batch_size: Batch size of the internal ``DataLoader``.
        num_workers: Worker processes of the internal ``DataLoader``.

    Returns:
        Mean pairwise distance as a Python float; ``0.0`` if fewer than two
        images are available.

    Raises:
        ValueError: If ``max_images`` < 1 or the dataset is empty.
    """
    if max_images < 1:
        raise ValueError(f"max_images must be >= 1, got {max_images}")
    if len(dataset) == 0:
        raise ValueError("Cannot estimate distances on an empty dataset")

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, drop_last=False)
    batches, total = [], 0
    for imgs, _ in loader:
        batches.append(imgs.to(device))
        total += imgs.size(0)
        if total >= max_images:
            break
    # The last batch may overshoot the budget; keep exactly max_images.
    X = torch.cat(batches, dim=0)[:max_images]
    N = X.size(0)
    if N < 2:
        return 0.0

    # Rolling a random permutation yields just another random permutation,
    # which can map an image onto itself (distance 0) and bias the mean
    # downward. Instead pair perm[i] with perm[i-1]: a single cycle through
    # all images, so no image is ever compared with itself.
    perm = torch.randperm(N, device=X.device)
    partner = torch.empty_like(perm)
    partner[perm] = torch.roll(perm, shifts=1)
    dists = torch.norm((X - X[partner]).reshape(N, -1), p=2, dim=1)
    return float(dists.mean().item())

@torch.no_grad()
def estimate_per_pixel_variance_global_mean(dataset, device="cpu", max_images: int = 2048, batch_size: int = 64,
                                            num_workers: int = 2) -> Tuple[float,int]:
    """Estimate the variance of all pixel values around one global mean.

    Up to ``max_images`` images are loaded (shuffled) and stacked into a
    ``[N, C, H, W]`` tensor; the result is the mean squared deviation of
    every element from the scalar mean over the whole stack.

    Args:
        dataset: Map-style dataset yielding ``(image_tensor, label)`` pairs
            whose images share one ``[C, H, W]`` shape.
        device: Device the images are moved to before the computation.
        max_images: Upper bound on the number of images used (>= 1).
        batch_size: Batch size of the internal ``DataLoader``.
        num_workers: Worker processes of the internal ``DataLoader``.

    Returns:
        ``(v_hat, d)`` with ``v_hat`` the variance estimate and
        ``d = C * H * W`` the per-image dimensionality.

    Raises:
        ValueError: If ``max_images`` < 1 or the dataset is empty.
    """
    if max_images < 1:
        raise ValueError(f"max_images must be >= 1, got {max_images}")
    if len(dataset) == 0:
        raise ValueError("Cannot estimate variance on an empty dataset")

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        num_workers=num_workers, drop_last=False)
    batches, total = [], 0
    for imgs, _ in loader:
        batches.append(imgs.to(device))
        total += imgs.size(0)
        if total >= max_images:
            break
    # The last batch may overshoot the budget; keep exactly max_images.
    X = torch.cat(batches, dim=0)[:max_images]   # [N,C,H,W]
    _, C, H, W = X.shape
    d = C * H * W
    mu = X.mean()                                # single global mean
    v_hat = ((X - mu) ** 2).mean().item()
    return v_hat, d

def compute_c_factor(r: float, v_hat: float, d: int) -> float:
    """Return ``r^2 / (2 * d * v_hat)``.

    A ``1e-20`` term is added to the denominator so a zero variance does not
    divide by zero.

    Args:
        r: Average distance between images.
        v_hat: Per-pixel variance estimate.
        d: Number of elements per image.
    """
    numerator = r * r
    denominator = 2.0 * d * v_hat + 1e-20
    return numerator / denominator

def mf_from_tau(alpha: float, tau: float, c: float) -> float:
    """Return the noise multiplier ``mf`` implied by ``tau``.

    Evaluates ``sqrt((alpha^2 / (tau * (1 - alpha)^2) - 1) / (2 * c))``.
    ``tau`` and ``c`` are floored at ``1e-20`` and a negative radicand is
    clamped to zero, so the result is always a finite, non-negative float.

    Args:
        alpha: Mixup weight of the target image; must differ from 1.
        tau: Target threshold (values <= 0 are treated as ``1e-20``).
        c: Dataset factor as returned by ``compute_c_factor``.

    Raises:
        ValueError: If ``alpha == 1``, where the bound diverges (previously
            an unexplained ``ZeroDivisionError``).
    """
    one_minus_alpha = 1.0 - alpha
    if one_minus_alpha == 0.0:
        raise ValueError("alpha must differ from 1: the mf bound diverges when (1 - alpha) is zero")
    # mf >= sqrt( (alpha^2/(tau(1-alpha)^2) - 1) / (2c) )
    a2 = alpha * alpha
    inner = max(a2 / (max(tau, 1e-20) * one_minus_alpha ** 2) - 1.0, 0.0)
    return float(np.sqrt(inner / (2.0 * max(c, 1e-20))))

# ---------------------------
# Mixup pair dataset (on-the-fly; partner index is deterministic per index, the added noise is freshly sampled)
# ---------------------------

class MixupPairDataset(torch.utils.data.Dataset):
    """
    Mixup pairs built on the fly from a base dataset.

    For each index i:
      input  = alpha * x_i + (1-alpha) * (x_j + n), with ||n|| = mf * r
      target = x_i

    The partner index j comes from a NumPy generator seeded with
    (seed + idx), so the same pairing is reproduced across different taus.
    The noise n itself is drawn with torch.randn_like, i.e. from the global
    torch RNG, and is therefore different on every access.
    """

    def __init__(self, base_subset, alpha: float, r: float, mf: float, seed: int = 0):
        self.base, self.alpha = base_subset, alpha
        self.r, self.mf = r, mf
        self.seed = int(seed)

    def __len__(self):
        return len(self.base)

    @torch.no_grad()
    def _noise(self, x: torch.Tensor, target_norm: float) -> torch.Tensor:
        """Gaussian noise shaped like ``x`` with each sample scaled to L2 norm ``target_norm``."""
        batch = x.size(0)
        gauss = torch.randn_like(x)
        per_sample_norm = gauss.view(batch, -1).norm(p=2, dim=1)
        # The 1e-12 term guards against division by a zero norm.
        gain = target_norm / (per_sample_norm + 1e-12)
        return gauss * gain.view(batch, 1, 1, 1)

    def __getitem__(self, idx: int):
        """Return ``(mix, x_i)``; ``mix`` is clamped to the valid pixel range."""
        n_items = len(self.base)
        target, _ = self.base[idx]

        # Seeded per (seed, idx) so the partner is identical across taus.
        partner = int(np.random.default_rng(self.seed + int(idx)).integers(0, n_items))
        if partner == idx:
            partner = (partner + 1) % n_items
        other, _ = self.base[partner]

        # _noise works on batches, so wrap the partner in a batch of one.
        noise = self._noise(other.unsqueeze(0), target_norm=self.mf * self.r)[0]

        blended = self.alpha * target + (1.0 - self.alpha) * (other + noise)
        return clamp_imagenet_normalized(blended.unsqueeze(0))[0], target

# ---------------------------
# U-Net (lightweight)
# ---------------------------

class DoubleConv(nn.Module):
    """Two 3x3 convolutions with padding 1, each followed by an in-place ReLU.

    Args:
        in_c: Number of input channels of the first convolution.
        out_c: Number of output channels of both convolutions.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        for channels_in in (in_c, out_c):
            stages.append(nn.Conv2d(channels_in, out_c, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class UNetSmall(nn.Module):
    """Three-level U-Net with skip connections and a clamped output.

    The encoder applies ``DoubleConv`` + 2x2 max-pooling three times, the
    bottleneck is one more ``DoubleConv``, and the decoder upsamples with
    stride-2 transposed convolutions, concatenating the matching encoder
    feature map on the channel axis before each ``DoubleConv``. The final
    1x1 convolution output goes through ``clamp_imagenet_normalized``.

    Input height and width should be multiples of 8; otherwise the upsampled
    maps do not match the skip tensors in ``torch.cat``.

    Args:
        in_ch: Input channels.
        out_ch: Output channels.
        base: Channel width of the first level; it doubles at each level.
    """

    def __init__(self, in_ch=3, out_ch=3, base=32):
        super().__init__()
        w1, w2, w3, w4 = base, base * 2, base * 4, base * 8

        # Encoder
        self.down1 = DoubleConv(in_ch, w1)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = DoubleConv(w1, w2)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = DoubleConv(w2, w3)
        self.pool3 = nn.MaxPool2d(2)
        self.bott  = DoubleConv(w3, w4)

        # Decoder: each DoubleConv sees upsampled + skip channels (2x width).
        self.up3   = nn.ConvTranspose2d(w4, w3, 2, stride=2)
        self.dec3  = DoubleConv(w4, w3)
        self.up2   = nn.ConvTranspose2d(w3, w2, 2, stride=2)
        self.dec2  = DoubleConv(w3, w2)
        self.up1   = nn.ConvTranspose2d(w2, w1, 2, stride=2)
        self.dec1  = DoubleConv(w2, w1)

        self.outc  = nn.Conv2d(w1, out_ch, 1)

    def forward(self, x):
        skip1 = self.down1(x)
        skip2 = self.down2(self.pool1(skip1))
        skip3 = self.down3(self.pool2(skip2))
        h = self.bott(self.pool3(skip3))

        h = self.dec3(torch.cat([self.up3(h), skip3], dim=1))
        h = self.dec2(torch.cat([self.up2(h), skip2], dim=1))
        h = self.dec1(torch.cat([self.up1(h), skip1], dim=1))

        return clamp_imagenet_normalized(self.outc(h))

# ---------------------------
# Training / evaluation
# ---------------------------

def snr_db(x: torch.Tensor, xhat: torch.Tensor) -> torch.Tensor:
    """Per-sample signal-to-noise ratio in decibels.

    Computes ``10 * log10(sum(x^2) / sum((x - xhat)^2))`` over all non-batch
    elements; ``1e-20`` is added to the error energy to avoid division by zero.

    Args:
        x: Reference tensor with the batch on dim 0.
        xhat: Estimate with the same number of elements per sample.

    Returns:
        1-D tensor with one SNR value per batch entry.
    """
    batch = x.size(0)
    ref = x.view(batch, -1)
    est = xhat.view(batch, -1)
    signal_energy = ref.pow(2).sum(dim=1)
    error_energy = (ref - est).pow(2).sum(dim=1) + 1e-20
    return 10.0 * torch.log10(signal_energy / error_energy)

def train_unet(model, loader, epochs: int, device, lr: float = 1e-3, tv_weight: float = 1e-4, l2_weight: float = 0.0):
    """Train ``model`` in place with Adam and print the mean loss per epoch.

    The per-batch loss is L1(pred, target) plus ``tv_weight`` times the mean
    absolute vertical + horizontal neighbour differences of the prediction,
    plus ``l2_weight`` times the mean squared prediction.

    Args:
        model: Module mapping a mixup batch to a reconstruction.
        loader: Iterable of ``(mix, target)`` batches.
        epochs: Number of passes over ``loader``.
        device: Device the batches are moved to.
        lr: Adam learning rate.
        tv_weight: Weight of the total-variation term.
        l2_weight: Weight of the squared-magnitude term.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    recon_criterion = nn.L1Loss()

    for epoch_idx in range(epochs):
        model.train()
        batch_losses = []
        for noisy, clean in loader:
            noisy, clean = noisy.to(device), clean.to(device)
            recon = model(noisy)

            # Anisotropic total variation over dims 2 (height) and 3 (width).
            vertical_tv = (recon[:, :, 1:, :] - recon[:, :, :-1, :]).abs().mean()
            horizontal_tv = (recon[:, :, :, 1:] - recon[:, :, :, :-1]).abs().mean()

            total = (
                recon_criterion(recon, clean)
                + tv_weight * (vertical_tv + horizontal_tv)
                + l2_weight * (recon.pow(2).mean())
            )

            optimizer.zero_grad()
            total.backward()
            optimizer.step()
            batch_losses.append(float(total.item()))
        print(f"[train] epoch {epoch_idx + 1:02d}/{epochs}  loss={np.mean(batch_losses):.5f}")

@torch.no_grad()
def evaluate_unet(model, loader, device, track_best: bool = True, max_batches=None):
    """Evaluate reconstruction SNR of ``model`` over ``loader``.

    Args:
        model: Module mapping a mixup batch to a reconstruction; it is put
            into eval mode.
        loader: Iterable of ``(mix, target)`` batches.
        device: Device the batches are moved to.
        track_best: If true, remember the sample with the highest SNR.
        max_batches: Optional cap on the number of batches; the check runs
            after a batch is processed, so at least one batch is evaluated.

    Returns:
        ``(mean_snr, std_snr, best)``. ``best`` is ``None`` when
        ``track_best`` is false, otherwise a dict with keys ``"snr"``,
        ``"gt"``, ``"in"``, ``"out"`` (CPU tensors) and ``"global_index"``,
        the sample's position in iteration order (equal to the dataset index
        only for an unshuffled loader). ``std_snr`` is ``0.0`` when a single
        sample was evaluated, instead of the NaN that the Bessel-corrected
        standard deviation would give.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    snrs = []
    best = {"snr": -1e9, "gt": None, "in": None, "out": None, "global_index": None}
    seen = 0
    for b, (mix, tgt) in enumerate(loader):
        B = mix.size(0)
        mix = mix.to(device)
        tgt = tgt.to(device)
        pred = model(mix)
        s = snr_db(tgt, pred)
        snrs.append(s.cpu())
        if track_best:
            k = int(torch.argmax(s).item())
            top = s[k].item()
            if top > best["snr"]:
                best["snr"] = top
                best["gt"]  = tgt[k].detach().cpu()
                best["in"]  = mix[k].detach().cpu()
                best["out"] = pred[k].detach().cpu()
                best["global_index"] = seen + k
        seen += B
        if (max_batches is not None) and (b + 1 >= max_batches):
            break

    if not snrs:
        raise ValueError("evaluate_unet received an empty loader; there is nothing to evaluate")

    all_snr = torch.cat(snrs, dim=0)
    mean_snr = float(all_snr.mean().item())
    # torch.std applies Bessel's correction and is NaN for one element.
    std_snr = float(all_snr.std().item()) if all_snr.numel() > 1 else 0.0
    return mean_snr, std_snr, (best if track_best else None)

# ---------------------------
# Helpers for figure & tau formatting
# ---------------------------

def tau_to_tex_power10(tau: float) -> str:
    r"""Format ``tau`` as a LaTeX math string for a figure title.

    Exact powers of ten become ``$\tau=10^{k}$``. Other positive values use
    normalized scientific notation with one mantissa decimal, e.g.
    ``$\tau=5.0\times 10^{-3}$``. Non-positive values give ``$\tau=0$``.

    Args:
        tau: Value to format.

    Returns:
        The LaTeX string, including the surrounding ``$``.
    """
    if tau <= 0:
        return r"$\tau=0$"
    exp = int(np.round(np.log10(tau)))
    # Relative tolerance only: np.isclose's default atol=1e-8 would declare
    # every tau below ~1e-8 "close" to a power of ten and mislabel it.
    if np.isclose(tau, 10.0**exp, rtol=1e-5, atol=0.0):
        return rf"$\tau=10^{{{exp}}}$"
    # Let the float formatter pick mantissa and exponent so the mantissa is
    # always in [1, 10), including after rounding (9.96e-3 -> 1.0e-02).
    mant_str, exp_str = f"{tau:.1e}".split("e")
    return rf"$\tau={mant_str}\times 10^{{{int(exp_str)}}}$"

# ---------------------------
# Main orchestration (multi-tau)
# ---------------------------

def _parse_taus(spec: str) -> List[float]:
    """Parse a comma-separated string of tau values, skipping empty entries."""
    taus = [float(tok) for tok in (part.strip() for part in spec.split(",")) if tok]
    if not taus:
        raise ValueError("No tau values provided. Use --taus like '1e-1,1e-2,1e-3'.")
    return taus


def _run_multitau(args) -> str:
    """Run the per-tau train/evaluate loop, save all figures, return the panel path."""
    # Repro
    random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed)

    device = torch.device(args.device)

    taus_sorted = sorted(_parse_taus(args.taus), reverse=True)
    tau_max = taus_sorted[0]
    print(f"[CFG] taus (sorted desc): {taus_sorted}  | tau_max={tau_max}")

    # --- Datasets ---
    tiny_train, _tiny_eval = get_dataset("tiny-imagenet")
    _cif_train, cif_eval   = get_dataset("cifar10")

    tiny_sub = make_subdataset(tiny_train, args.train_size, seed=args.seed)
    cif_sub  = make_subdataset(cif_eval,  args.test_size,  seed=args.seed)

    # --- Shared stats baseline (d is same after transforms) ---
    print("\n[STATS] Estimating dataset statistics (independent of tau)...")
    r_tiny = estimate_average_distance(tiny_sub, device=device, max_images=args.stats_max, batch_size=args.batch)
    v_tiny, d = estimate_per_pixel_variance_global_mean(tiny_sub, device=device, max_images=args.stats_max, batch_size=args.batch)
    r_cif  = estimate_average_distance(cif_sub,  device=device, max_images=args.stats_max, batch_size=args.batch)
    v_cif, _d2 = estimate_per_pixel_variance_global_mean(cif_sub,  device=device, max_images=args.stats_max, batch_size=args.batch)
    print(f"[TINY] r={r_tiny:.4f}  v_hat={v_tiny:.4e}  d={d}")
    print(f"[CIFAR10] r={r_cif:.4f}  v_hat={v_cif:.4e}  d={d}")

    # The c factors depend only on dataset statistics, so compute them once.
    c_tiny = compute_c_factor(r_tiny, v_tiny, d)
    c_cif = compute_c_factor(r_cif, v_cif, d)
    pin = "cuda" in args.device

    # Storage for per-tau results (for final panel)
    per_tau: Dict[float, Dict[str, Any]] = {}

    # --------- Run experiment for each tau ----------
    for tau in taus_sorted:
        print(f"\n===== [TAU {tau}] ({tau_to_tex_power10(tau)}) =====")
        mf_tiny = mf_from_tau(args.alpha, tau, c_tiny)
        mf_cif = mf_from_tau(args.alpha, tau, c_cif)
        print(f"[tau={tau:g}] c_tiny={c_tiny:.4f} -> mf_tiny={mf_tiny:.4f} | c_cif={c_cif:.4f} -> mf_cif={mf_cif:.4f}")

        # Mixup datasets; the partner index is seeded per item so the same
        # pairing is reproduced for every tau.
        train_mix = MixupPairDataset(tiny_sub, alpha=args.alpha, r=r_tiny, mf=mf_tiny, seed=args.seed)
        test_mix  = MixupPairDataset(cif_sub,  alpha=args.alpha, r=r_cif,  mf=mf_cif,  seed=args.seed+1)

        train_loader = DataLoader(train_mix, batch_size=args.batch, shuffle=True,  num_workers=2, pin_memory=pin)
        test_loader  = DataLoader(test_mix,  batch_size=args.batch, shuffle=False, num_workers=2, pin_memory=pin)

        # Model
        model = UNetSmall(base=48).to(device)

        # Train
        print("[Training U-Net on Tiny-ImageNet mixups]")
        train_unet(model, train_loader, epochs=args.epochs, device=device, lr=args.lr, tv_weight=args.tv, l2_weight=args.l2)

        # Evaluate on CIFAR-10 (track best)
        print("[Evaluating on CIFAR-10 mixups]")
        mean_snr, std_snr, best = evaluate_unet(model, test_loader, device=device, track_best=True)
        print(f"[RESULT] tau={tau:g} | CIFAR-10 recovery SNR: {mean_snr:.2f} ± {std_snr:.2f} dB | best_snr={best['snr']:.2f} dB | best_index={best['global_index']}")

        # Save individual triplet (kept; this is NOT the multi-tau panel)
        indiv_path = os.path.join(args.outdir, f"best_unet_recovery_cifar10_tau{tau:g}.png")
        save_best_triplet(indiv_path, best, title=f"CIFAR-10 best (tau={tau:g}, alpha={args.alpha}) — SNR={max(best['snr'],0):.2f} dB")

        per_tau[tau] = {
            "mean_snr": mean_snr, "std_snr": std_snr, "best": best,
            "model": model, "mf_tiny": mf_tiny, "mf_cif": mf_cif,
            "test_mix": test_mix
        }

    # --------- Build the single summary figure (ONLY tau text allowed) ---------
    print("\n[FIG] Building multi-tau summary figure (3 rows x N columns)...")
    N = len(taus_sorted)

    # Use the "best global index" from the largest tau. The test loader is
    # not shuffled, so this position is also the index into test_mix.
    idx_best = per_tau[tau_max]["best"]["global_index"]
    gt_from_tau_max = per_tau[tau_max]["best"]["gt"]

    # Collect per-column images. Each test_mix[idx_best] access draws fresh
    # noise, so rows 2-3 show a new noise realisation for that image pair.
    rows = {1: [], 2: [], 3: []}
    for tau in taus_sorted:
        rows[1].append(gt_from_tau_max)  # repeat GT
        mix_i, _ = per_tau[tau]["test_mix"][idx_best]
        rows[2].append(mix_i.detach().cpu())
        with torch.no_grad():
            out = per_tau[tau]["model"](mix_i.unsqueeze(0).to(device))[0].detach().cpu()
        rows[3].append(out)

    # Plot with ONLY tau as column titles
    fig_h = 3 * 3.2
    fig_w = N * 3.2
    fig = plt.figure(figsize=(fig_w, fig_h))

    col_titles = [tau_to_tex_power10(t) for t in taus_sorted]

    for r in range(1, 4):
        for c, img in enumerate(rows[r], start=1):
            ax = fig.add_subplot(3, N, (r-1)*N + c)
            ax.imshow(chw_to_numpy_img01(img))
            ax.axis("off")
            if r == 1:
                ax.set_title(col_titles[c-1], fontsize=20)

    out_panel = os.path.join(args.outdir, f"panel_multitau_{len(taus_sorted)}cols.png")
    fig.savefig(out_panel, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"[saved] {out_panel}")

    print("\n[SUMMARY]")
    for tau in taus_sorted:
        m = per_tau[tau]["mean_snr"]; s = per_tau[tau]["std_snr"]; b = per_tau[tau]["best"]["snr"]
        print(f"  tau={tau:g} ({tau_to_tex_power10(tau)})  |  SNR={m:.2f}±{s:.2f} dB  | best={b:.2f} dB")

    return out_panel


def main():
    """Command-line entry point: run the multi-tau U-Net attack experiment.

    Parses the arguments, mirrors stdout/stderr into a log file inside
    ``--outdir``, then for every tau trains a fresh U-Net on Tiny-ImageNet
    mixups, evaluates it on CIFAR-10 mixups and saves a per-tau figure,
    followed by one 3-row summary panel. The process streams are restored
    and the log file is closed even if the experiment raises.
    """
    ap = argparse.ArgumentParser("Non-linear (U-Net) attack across multiple taus.")
    ap.add_argument("--alpha", type=float, default=0.7)
    ap.add_argument("--taus", type=str, default="1e-00, 1e-01, 1e-02, 1e-03, 1e-04, 1e-05, 1e-06",
                    help="Comma-separated list of tau values, e.g. '1e-1,1e-2,1e-3'")
    ap.add_argument("--train_size", type=int, default=4000, help="Tiny-ImageNet subset size")
    ap.add_argument("--test_size", type=int, default=2000, help="CIFAR-10 subset size")
    ap.add_argument("--batch", type=int, default=32)
    ap.add_argument("--epochs", type=int, default=30)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--tv", type=float, default=1e-4)
    ap.add_argument("--l2", type=float, default=0.0)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--stats_max", type=int, default=2048)
    ap.add_argument("--outdir", type=str, default="unet_attack_figs")
    ap.add_argument("--logfile", type=str, default=None, help="Optional log filename (defaults to timestamped).")
    args = ap.parse_args()

    # Prepare outdir and Tee logger
    os.makedirs(args.outdir, exist_ok=True)
    log_name = args.logfile or f"log_multitau_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
    log_path = os.path.join(args.outdir, log_name)
    tee = Tee(log_path)
    sys.stdout = tee
    sys.stderr = tee
    try:
        print(f"[LOG] Tee logging to: {log_path}")
        out_panel = _run_multitau(args)
    except Exception:
        # The interpreter prints the traceback only after `finally` has put
        # the real stderr back, so record it in the log file here as well.
        import traceback
        traceback.print_exc(file=tee.file)
        raise
    finally:
        # Restore stdio and close tee file, also on failure, so the process
        # streams never stay redirected to a dead logger.
        sys.stdout = tee.stdout
        sys.stderr = tee.stderr
        tee.close()

    print(f"Logs saved to: {log_path}")
    print(f"Figure saved to: {out_panel}")

# ---------------------------
# Utility: save triplet (per-tau PNGs; not the multi-tau panel)
# ---------------------------

def save_best_triplet(figpath: str, best: dict, title: str):
    """Save a 1x3 figure with ground truth, mixup input and U-Net output.

    Args:
        figpath: Output image path; its parent directory is created if needed.
        best: Dict with normalized CHW tensors under ``"gt"``, ``"in"`` and
            ``"out"``.
        title: Figure-level title.
    """
    parent = os.path.dirname(figpath) or "."
    os.makedirs(parent, exist_ok=True)

    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    panels = (("GT", "gt"), ("Mix", "in"), ("UNet", "out"))
    for ax, (label, key) in zip(axes, panels):
        ax.imshow(chw_to_numpy_img01(best[key]))
        ax.set_title(label)
        ax.axis("off")

    fig.suptitle(title)
    fig.savefig(figpath, dpi=140, bbox_inches="tight")
    plt.close(fig)
    print(f"[saved] {figpath}")

# Run the experiment only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
