"""
Linear reconstruction attack on mixup-with-noisy-partner.

This script saves ONE single figure on disk:
  - Columns are τ values (dynamic).
  - For each dataset we display three consecutive rows: Original, Mixup, Recovered.
  - The same image index is used across all τ for a dataset: the one with best SNR at the largest τ.
  - The ONLY text on the figure is τ above each column (no other labels/titles/annotations).

Datasets order in the figure (3 rows per dataset):
  rows 1-3 : MNIST        (Original, Mixup, Recovered)
  rows 4-6 : CIFAR-10     (Original, Mixup, Recovered)
  rows 7-9 : CIFAR-100    (Original, Mixup, Recovered)
  rows 10-12: Tiny-ImageNet (Original, Mixup, Recovered)
"""

import os
import sys
import io
import random
import argparse
from typing import List, Tuple, Dict
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

# ---------------------------
# Simple Tee to mirror stdout/stderr to a file
# ---------------------------

class Tee(io.TextIOBase):
    """Text stream that fans every write out to several underlying streams.

    Failures of an individual underlying stream (on write or flush) are
    ignored so that one broken sink never prevents output from reaching the
    others.
    """

    def __init__(self, *streams):
        # Kept as the tuple of sinks exactly as passed by the caller.
        self.streams = streams

    def write(self, s):
        """Write *s* to every stream, flush them all, and return ``len(s)``."""
        for sink in self.streams:
            try:
                sink.write(s)
            except Exception:
                pass
        # Flush right away so log lines show up promptly in every sink.
        self.flush()
        return len(s)

    def flush(self):
        """Flush every stream, ignoring per-stream errors."""
        for sink in self.streams:
            try:
                sink.flush()
            except Exception:
                pass

    def isatty(self):
        """Return True as soon as one underlying stream reports being a tty."""

        def _not_a_tty():
            return False

        for sink in self.streams:
            # Streams without an ``isatty`` attribute count as non-tty.
            probe = getattr(sink, "isatty", _not_a_tty)
            if probe():
                return True
        return False

# ---------------------------
# Preprocess / (de)normalize helpers
# ---------------------------

# Per-channel (three-value) mean and std passed to transforms.Normalize in
# get_transforms, and used again by the helpers in this file that undo that
# normalization.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]


def tau_to_latex(tau: float) -> str:
    """Format *tau* as a LaTeX math string for use as a column title.

    - ``tau <= 0`` is rendered as ``$\\tau=0$``.
    - A value that is (numerically) an exact power of ten is rendered as
      ``$\\tau=10^{k}$``.
    - Any other positive value is rendered in normalized scientific notation
      ``$\\tau=m\\times 10^{k}$`` with one decimal and ``1 <= m < 10``.
    """
    if tau <= 0:
        return r"$\tau=0$"

    log_tau = np.log10(tau)

    # Exact powers of ten: compare against the nearest integer exponent.
    nearest_exp = int(np.round(log_tau))
    if np.isclose(tau, 10.0**nearest_exp):
        return rf"$\tau=10^{{{nearest_exp}}}$"

    # Fallback for non-exact powers. The exponent must be the FLOOR of
    # log10(tau); rounding to nearest yields mantissas below 1 such as
    # "0.5 x 10^-2" for tau = 5e-3.
    exp = int(np.floor(log_tau))
    mant = tau / (10.0**exp)
    # A mantissa like 9.96 would print as "10.0" with one decimal; renormalize.
    if round(mant, 1) >= 10.0:
        mant /= 10.0
        exp += 1
    return rf"$\tau={mant:.1f}\times 10^{{{exp}}}$"


def get_transforms(dataset_type: str):
    """Build the torchvision preprocessing pipeline for *dataset_type*.

    The name is matched case-insensitively. Every supported dataset gets
    Resize(224) -> ToTensor -> Normalize(IMAGENET_MEAN, IMAGENET_STD); for
    "mnist" a Grayscale(num_output_channels=3) step is inserted right after
    the resize.

    Raises:
        ValueError: if the dataset name is not one of "mnist", "cifar10",
            "cifar100" or "tiny-imagenet".
    """
    dataset_type = dataset_type.lower()
    colour_sets = ("cifar10", "cifar100", "tiny-imagenet")
    is_mnist = dataset_type == "mnist"

    if not is_mnist and dataset_type not in colour_sets:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    pipeline = [transforms.Resize(224)]
    if is_mnist:
        # Expand the single channel to three so all datasets share one shape.
        pipeline.append(transforms.Grayscale(num_output_channels=3))
    pipeline.append(transforms.ToTensor())
    pipeline.append(transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return transforms.Compose(pipeline)

def get_dataset(dataset_type: str):
    """Return ``(train, test)`` datasets for *dataset_type* (case-insensitive).

    MNIST, CIFAR-10 and CIFAR-100 are created through torchvision with
    ``root="./data"`` and ``download=True``. Tiny-ImageNet is read with
    ImageFolder from ``./data/tiny-imagenet-200``; its "test" split is the
    ``val`` directory when that exists, otherwise the ``test`` directory.

    Raises:
        ValueError: for an unsupported dataset name.
        FileNotFoundError: if the Tiny-ImageNet directories are missing.
    """
    dataset_type = dataset_type.lower()
    tfm = get_transforms(dataset_type)

    if dataset_type == "tiny-imagenet":
        base_dir = os.path.join("./data", "tiny-imagenet-200")
        train_dir = os.path.join(base_dir, "train")
        val_dir = os.path.join(base_dir, "val")
        test_dir = os.path.join(base_dir, "test")
        # Prefer the validation folder as the held-out split when present.
        if os.path.isdir(val_dir):
            test_root = val_dir
        else:
            test_root = test_dir
        if not os.path.isdir(train_dir) or not os.path.isdir(test_root):
            raise FileNotFoundError("Tiny-ImageNet expected at ./data/tiny-imagenet-200/{train,val or test}")
        train = torchvision.datasets.ImageFolder(root=train_dir, transform=tfm)
        test = torchvision.datasets.ImageFolder(root=test_root, transform=tfm)
        return train, test

    builtin_sets = {
        "mnist": torchvision.datasets.MNIST,
        "cifar10": torchvision.datasets.CIFAR10,
        "cifar100": torchvision.datasets.CIFAR100,
    }
    dataset_cls = builtin_sets.get(dataset_type)
    if dataset_cls is None:
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    train = dataset_cls(root="./data", train=True, download=True, transform=tfm)
    test = dataset_cls(root="./data", train=False, download=True, transform=tfm)
    return train, test

def denorm_to_01(x: torch.Tensor) -> torch.Tensor:
    """Undo the mean/std normalization of *x* and clamp the result to [0, 1].

    The statistics are broadcast along dimension 1 (shape ``[1, C, 1, 1]``),
    so *x* is expected to be a batch laid out as ``[N, C, H, W]``.
    """
    channel_mean = torch.tensor(IMAGENET_MEAN, device=x.device)[None, :, None, None]
    channel_std = torch.tensor(IMAGENET_STD, device=x.device)[None, :, None, None]
    restored = x * channel_std + channel_mean
    return torch.clamp(restored, min=0.0, max=1.0)

def clamp_imagenet_normalized(x: torch.Tensor) -> torch.Tensor:
    """Project normalized images onto the range reachable from [0, 1] pixels.

    The batch is de-normalized, clamped to [0, 1], and normalized again with
    the same per-channel mean/std (broadcast as ``[1, C, 1, 1]``).
    """
    channel_mean = torch.tensor(IMAGENET_MEAN, device=x.device)[None, :, None, None]
    channel_std = torch.tensor(IMAGENET_STD, device=x.device)[None, :, None, None]
    pixels = torch.clamp(x * channel_std + channel_mean, min=0.0, max=1.0)
    renormalized = (pixels - channel_mean) / channel_std
    return renormalized

def chw_to_numpy_img01(x_chw: torch.Tensor) -> np.ndarray:
    """Convert one normalized ``[3, H, W]`` image to an ``H x W x 3`` numpy array.

    The per-channel normalization is undone and values are clamped to
    [0, 1] before the tensor is moved to the CPU and converted.
    """
    channel_mean = torch.tensor(IMAGENET_MEAN, device=x_chw.device).view(3, 1, 1)
    channel_std = torch.tensor(IMAGENET_STD, device=x_chw.device).view(3, 1, 1)
    pixels = torch.clamp(x_chw * channel_std + channel_mean, min=0, max=1)
    # Channels-last layout is what matplotlib's imshow expects.
    hwc = pixels.permute(1, 2, 0)
    return hwc.detach().cpu().numpy()

# ---------------------------
# Subset + basic stats (r, v_hat, c)
# ---------------------------

def make_subdataset(dataset, max_images: int, seed: int = 0):
    """Return a reproducible random ``Subset`` of at most *max_images* items.

    Indices are drawn without replacement from a NumPy generator seeded with
    *seed* and stored in ascending order.
    """
    total = len(dataset)
    keep = min(max_images, total)
    picked = np.random.default_rng(seed).choice(total, size=keep, replace=False)
    ordered = picked.tolist()
    ordered.sort()
    return Subset(dataset, ordered)

@torch.no_grad()
def estimate_average_distance(dataset, device="cpu", max_images: int = 2048, batch_size: int = 64,
                              num_workers: int = 2) -> float:
    """Estimate the mean L2 distance between two distinct images of *dataset*.

    Images are read in shuffled batches until at least *max_images* have been
    gathered (whole batches are kept, so slightly more may be used). Each
    image is then paired with exactly one other image and the flattened L2
    distances are averaged.

    Args:
        dataset: dataset yielding ``(image_tensor, label)`` pairs.
        device: device the gathered images are moved to.
        max_images: stop reading batches once this many images were seen.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes (``0`` loads in-process).

    Returns:
        The average pairwise distance as a Python float.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=False)
    batches, seen = [], 0
    for imgs, _ in loader:
        imgs = imgs.to(device)
        batches.append(imgs)
        seen += imgs.size(0)
        if seen >= max_images:
            break
    X = torch.cat(batches, dim=0)     # [N,C,H,W]
    N = X.size(0)

    # Pair every image with its successor along one random cycle. Unlike a
    # plain (or rolled) random permutation, a single N-cycle has no fixed
    # point for N >= 2, so no image is compared with itself and no spurious
    # zero distance drags the average down.
    order = torch.randperm(N, device=X.device)
    partner = torch.empty_like(order)
    partner[order] = torch.roll(order, shifts=1)

    dists = torch.norm((X - X[partner]).flatten(1), p=2, dim=1)
    return float(dists.mean().item())

@torch.no_grad()
def estimate_per_pixel_variance_global_mean(dataset, device="cpu", max_images: int = 2048, batch_size: int = 64,
                                            num_workers: int = 2) -> Tuple[float, int]:
    """Estimate the pixel variance around one global scalar mean.

    Images are read in shuffled batches until at least *max_images* have been
    gathered (whole batches are kept, so slightly more may be used). A single
    mean over all pixels of all gathered images is computed, and the mean
    squared deviation from it is returned together with the per-image
    dimensionality.

    Args:
        dataset: dataset yielding ``(image_tensor, label)`` pairs.
        device: device the gathered images are moved to.
        max_images: stop reading batches once this many images were seen.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes (``0`` loads in-process).

    Returns:
        ``(v_hat, d)`` where ``v_hat`` is the variance estimate and
        ``d = C * H * W`` is the number of values per image.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=False)
    batches, seen = [], 0
    for imgs, _ in loader:
        imgs = imgs.to(device)
        batches.append(imgs)
        seen += imgs.size(0)
        if seen >= max_images:
            break
    X = torch.cat(batches, dim=0)   # [N,C,H,W]
    N, C, H, W = X.shape
    d = C * H * W
    flat = X.view(N, -1)
    mu = flat.mean()                               # GLOBAL mean over all pixels
    v_hat = ((flat - mu) ** 2).mean().item()
    return v_hat, d

def compute_c_factor(r: float, v_hat: float, d: int) -> float:
    """Return ``r^2 / (2 * d * v_hat)``, guarded against a zero denominator.

    A tiny constant (1e-20) is added to the denominator so that a zero
    variance does not raise a division error.
    """
    squared_distance = r * r
    denominator = 2.0 * d * v_hat + 1e-20
    return squared_distance / denominator

# ---------------------------
# Mixup + noise + SNR
# ---------------------------

def add_noise_with_l2_norm_batch(x: torch.Tensor, target_norm: float) -> torch.Tensor:
    """Add Gaussian noise to each sample of *x*, rescaled to a fixed L2 norm.

    A standard-normal tensor shaped like *x* is drawn; each sample's noise is
    scaled so that its flattened L2 norm is *target_norm* (up to a 1e-12
    stabilizer). The per-sample scale is broadcast as ``[B, 1, 1, 1]``, i.e.
    *x* is treated as a 4-D batch.
    """
    batch = x.size(0)
    gaussian = torch.randn_like(x)
    per_sample_norm = gaussian.view(batch, -1).norm(p=2, dim=1)
    factor = target_norm / (per_sample_norm + 1e-12)
    return x + gaussian * factor.view(batch, 1, 1, 1)

def theoretical_mf_from_tau(alpha: float, tau: float, c: float) -> float:
    """Return the noise multiplier implied by the bound for a given tau.

    Evaluates ``sqrt( (1/(2c)) * (alpha^2 / (tau * (1-alpha)^2) - 1) )``.
    *tau* and *c* are floored at 1e-20 and the bracketed term is floored at 0,
    so the result is always a non-negative float.
    """
    safe_tau = max(tau, 1e-20)
    safe_c = max(c, 1e-20)
    ratio = (alpha**2) / (safe_tau * (1 - alpha)**2)
    excess = max(ratio - 1.0, 0.0)
    return float(np.sqrt(excess / (2.0 * safe_c)))

def recovery_snr_db(x: torch.Tensor, xhat: torch.Tensor) -> torch.Tensor:
    """Return the per-image reconstruction SNR in decibels.

    For each sample along dimension 0 this is
    ``10 * log10(||x||^2 / (||x - xhat||^2 + 1e-20))`` computed on the
    flattened sample; the result has one entry per sample.
    """
    n_img = x.size(0)
    clean = x.view(n_img, -1)
    estimate = xhat.view(n_img, -1)
    signal_energy = clean.pow(2).sum(dim=1)
    # The small constant keeps a perfect reconstruction from dividing by zero.
    error_energy = (clean - estimate).pow(2).sum(dim=1) + 1e-20
    return 10.0 * torch.log10(signal_energy / error_energy)

# ---------------------------
# Build a mixup graph + observed mixtures
# ---------------------------

@torch.no_grad()
def build_mixup_observations(X: torch.Tensor, alpha: float, r: float, mf: float, device="cpu"):
    """Mix every image with a noisy copy of a different image.

    Construction: ``Y[i] = alpha * X[i] + (1-alpha) * (X[j] + noise_ij)`` with
    ``j = partner[i]`` and ``||noise_ij|| = mf * r``.

    Args:
        X: ``[N,C,H,W]`` normalized clean images of the subset.
        alpha: weight of the image itself in the mixture.
        r: distance scale; multiplied by *mf* to get the noise norm.
        mf: noise multiplier.
        device: device on which the partner indices are created; it should be
            the device *X* lives on.

    Returns:
        ``(Y, partner)`` where ``Y`` is ``[N,C,H,W]`` and ``partner`` is a
        LongTensor ``[N]`` holding the partner index of each image.
    """
    N = X.size(0)

    # Partners follow one random cycle through all indices: each element of a
    # random ordering is assigned its predecessor in that ordering. A single
    # N-cycle has no fixed point for N >= 2, so no image is ever mixed with
    # itself. (Rolling a random permutation, by contrast, can keep or even
    # create fixed points.)
    order = torch.randperm(N, device=device)
    partner = torch.empty_like(order)
    partner[order] = torch.roll(order, shifts=1)

    target_norm = float(mf) * float(r)
    noisy_partner = add_noise_with_l2_norm_batch(X[partner], target_norm=target_norm)
    Y = alpha * X + (1.0 - alpha) * noisy_partner
    return Y, partner

# ---------------------------
# TV loss (anisotropic)
# ---------------------------

def tv_loss(x: torch.Tensor) -> torch.Tensor:
    """Anisotropic total-variation penalty for a 4-D batch.

    Takes absolute differences between neighbours along dimension 2 and along
    dimension 3, averages each set over all of its elements, and returns the
    sum of the two means as a scalar tensor.
    """
    along_rows = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs()
    along_cols = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs()
    row_term = along_rows.mean()
    col_term = along_cols.mean()
    return row_term + col_term

# ---------------------------
# Linear attack optimizer
# ---------------------------

def run_linear_attack(
    X_clean: torch.Tensor, Y_obs: torch.Tensor, partner_idx: torch.Tensor,
    alpha: float, steps: int = 200, lr: float = 0.05,
    lambda_tv: float = 1e-3, lambda_l2: float = 1e-4,
) -> torch.Tensor:
    """Estimate the clean images from the mixtures with Adam.

    The variable ``V`` starts as a copy of *Y_obs* and minimizes
    ``MSE(Y_obs, alpha*V + (1-alpha)*V[partner_idx])
      + lambda_tv * TV(V) + lambda_l2 * mean(V^2)``.
    After every optimizer step ``V`` is clamped back to the valid normalized
    pixel range. Progress is printed roughly five times over the run.

    *X_clean* is accepted for signature compatibility but is not read here.

    Returns:
        The detached estimate, same shape as *Y_obs*.
    """
    V = Y_obs.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([V], lr=lr)
    criterion = nn.MSELoss()
    log_every = max(steps // 5, 1)
    partner_weight = 1.0 - alpha

    for step_no in range(1, steps + 1):
        optimizer.zero_grad()

        # Forward model: re-mix the current estimate through the known graph.
        mixed = alpha * V + partner_weight * V[partner_idx]
        fit_term = criterion(mixed, Y_obs)
        tv_term = lambda_tv * tv_loss(V)
        l2_term = lambda_l2 * (V.pow(2).mean())
        total = fit_term + tv_term + l2_term

        total.backward()
        optimizer.step()

        # Projection step, done outside autograd so it is not tracked.
        with torch.no_grad():
            V.copy_(clamp_imagenet_normalized(V))

        if step_no % log_every == 0:
            print(f"    [attack] step {step_no:4d}  L={total.item():.6f}  mse={fit_term.item():.6f}  tv={tv_term.item():.6f}  l2={l2_term.item():.6f}")
    return V.detach()

# ---------------------------
# Visualization (single multi-τ grid: Original / Mixup / Recovered)
# ---------------------------

def save_multi_tau_grid(
    datasets_in_order: List[str],
    taus: List[float],
    per_dataset_results: Dict[str, Dict],
    out_dir: str,
    fname: str = "multi_dataset_multi_tau.png",
):
    """Render and save the single Original / Mixup / Recovered grid figure.

    Each dataset occupies three consecutive rows (original, mixture,
    reconstruction) and each τ one column. The only text drawn is the τ title
    above the first row.

    Args:
        datasets_in_order: dataset keys, top to bottom.
        taus: τ values, left to right.
        per_dataset_results: for every dataset a dict providing
            ``"best_index_at_max_tau"``, ``"X_clean"`` and
            ``"per_tau"[tau]["Y_obs" / "X_rec"]``.
        out_dir: output directory (created if missing).
        fname: file name of the saved image.
    """
    os.makedirs(out_dir, exist_ok=True)
    n_cols = len(taus)
    n_rows = 3 * len(datasets_in_order)

    figure_size = (max(3 * n_cols, 6), max(2.5 * n_rows, 6))
    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figure_size, squeeze=False)

    # τ titles go on the top row only.
    for col, tau in enumerate(taus):
        axes[0, col].set_title(tau_to_latex(tau), fontsize=30)

    for block_no, ds_name in enumerate(datasets_in_order):
        ds_res = per_dataset_results[ds_name]
        chosen_idx = ds_res["best_index_at_max_tau"]
        first_row = 3 * block_no

        # The original image is identical in every column; convert it once.
        original_np = chw_to_numpy_img01(ds_res["X_clean"][chosen_idx])

        for col, tau in enumerate(taus):
            tau_entry = ds_res["per_tau"][tau]
            mixup_np = chw_to_numpy_img01(tau_entry["Y_obs"][chosen_idx])
            recovered_np = chw_to_numpy_img01(tau_entry["X_rec"][chosen_idx])

            # Row offsets: 0 = original, 1 = mixup, 2 = recovered.
            for offset, picture in enumerate((original_np, mixup_np, recovered_np)):
                panel = axes[first_row + offset, col]
                panel.imshow(picture)
                panel.axis("off")

    plt.tight_layout()
    out_path = os.path.join(out_dir, fname)
    plt.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close()
    print(f"[saved grid] {out_path}")

# ---------------------------
# Main experiment
# ---------------------------

def run_experiment(
    datasets: List[str], taus: List[float], alpha: float, sub_size: int,
    device: str, seed: int,
    max_images_for_stats: int, batch_size_stats: int,
    attack_steps: int, attack_lr: float, lambda_tv: float, lambda_l2: float,
    out_dir: str,
):
    """Run the attack for every dataset and τ, then save the summary figure.

    For each dataset a random test subset is loaded, the statistics
    ``r``, ``v_hat`` and ``c`` are estimated on it, and for each τ mixtures
    are built and attacked. The image with the best SNR at the largest τ is
    the one shown in the figure for that dataset.

    Args:
        datasets: dataset names (matched case-insensitively).
        taus: τ values to evaluate.
        alpha: mixup weight of the image itself.
        sub_size: number of test images per dataset.
        device: device the images are moved to (the script entry point passes
            a ``torch.device``).
        seed: seed for ``random``, NumPy and torch; ``None`` skips seeding.
        max_images_for_stats: image cap for the statistics estimators.
        batch_size_stats: batch size for loading and for the estimators.
        attack_steps, attack_lr, lambda_tv, lambda_l2: attack optimizer
            settings forwarded to ``run_linear_attack``.
        out_dir: directory the figure is written to.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    print("\n=== Linear attack on mixup (known graph) ===")
    print(f"alpha    : {alpha}")
    print(f"taus     : {taus}")
    print(f"subset   : {sub_size} images per dataset")
    print(f"device   : {device}")
    print(f"attack   : steps={attack_steps}, lr={attack_lr}, tv={lambda_tv}, l2={lambda_l2}\n")

    # per_dataset_results layout (keys are LOWER-CASED dataset names):
    # per_dataset_results[key] = {
    #   "X_clean": Tensor [N,C,H,W] while the dataset is processed, then
    #              trimmed to [1,C,H,W] (the chosen image only),
    #   "per_tau": { tau_value: {"X_rec": Tensor, "Y_obs": Tensor (trimmed the
    #                same way), "snr_db": Tensor[N]} },
    #   "max_tau": float,
    #   "best_index_at_max_tau": int   # index INTO the stored tensors,
    #   "best_index_in_subset": int    # index of that image in the subset
    # }
    per_dataset_results: Dict[str, Dict] = {}

    for ds_name in datasets:
        # get_dataset matches names case-insensitively, so results must be
        # keyed by the normalized name too; otherwise e.g. "MNIST" would be
        # processed but silently left out of the figure.
        ds_key = ds_name.lower()

        print(f"\n--- Dataset: {ds_name} ---")
        _, test_ds = get_dataset(ds_name)
        sub_ds = make_subdataset(test_ds, max_images=sub_size, seed=seed or 0)

        # Load the subset into one tensor X_clean
        loader = DataLoader(sub_ds, batch_size=batch_size_stats, shuffle=False, num_workers=2)
        X_clean = torch.cat([imgs.to(device) for imgs, _ in loader], dim=0)  # [N,C,H,W]
        N, C, H, W = X_clean.shape
        print(f"[subset] N={N}, C={C}, H={H}, W={W}")

        # Dataset stats on the same subset
        r = estimate_average_distance(sub_ds, device=device, max_images=max_images_for_stats, batch_size=batch_size_stats)
        v_hat, d = estimate_per_pixel_variance_global_mean(sub_ds, device=device, max_images=max_images_for_stats, batch_size=batch_size_stats)
        c = compute_c_factor(r, v_hat, d)
        print(f"[stats] r={r:.6f} | d={d} | v_hat={v_hat:.4e} | c={c:.4f}")

        ds_store = {
            "X_clean": X_clean.detach(),
            "per_tau": {},
            "max_tau": max(taus) if len(taus) > 0 else None,
            "best_index_at_max_tau": None,
            "best_index_in_subset": None,
        }

        for tau in taus:
            mf = theoretical_mf_from_tau(alpha=alpha, tau=tau, c=c)
            print(f"\n  tau={tau:.6g} -> mf(theory)={mf:.6f} (target ||n|| = {mf*r:.2f})")

            with torch.no_grad():
                Y_obs, partner = build_mixup_observations(X_clean, alpha=alpha, r=r, mf=mf, device=device)

            X_rec = run_linear_attack(
                X_clean, Y_obs, partner, alpha=alpha,
                steps=attack_steps, lr=attack_lr, lambda_tv=lambda_tv, lambda_l2=lambda_l2
            )

            snr_db = recovery_snr_db(X_clean, X_rec)
            mean_db = float(snr_db.mean().item())
            std_db  = float(snr_db.std().item())
            print(f"  [result] avg recovery SNR = {mean_db:.2f} ± {std_db:.2f} dB over {N} images")

            ds_store["per_tau"][tau] = {
                "X_rec": X_rec.detach(),
                "Y_obs": Y_obs.detach(),
                "snr_db": snr_db.detach().cpu()
            }

        # Choose best index at the largest τ
        if len(taus) > 0:
            tau_max = ds_store["max_tau"]
            snr_max_tau = ds_store["per_tau"][tau_max]["snr_db"]
            best_idx = int(torch.argmax(snr_max_tau).item())
            print(f"  [select] best image index at max τ={tau_max:g}: {best_idx} "
                  f"(SNR={snr_max_tau[best_idx].item():.2f} dB)")

            # Only the chosen image is ever drawn. Keep just that slice so the
            # full [N,C,H,W] tensors of every τ are not held for every dataset
            # until the figure is saved. clone() is required: a plain slice is
            # a view that would keep the whole storage alive.
            keep = slice(best_idx, best_idx + 1)
            ds_store["X_clean"] = X_clean[keep].detach().clone()
            for tau_entry in ds_store["per_tau"].values():
                tau_entry["X_rec"] = tau_entry["X_rec"][keep].clone()
                tau_entry["Y_obs"] = tau_entry["Y_obs"][keep].clone()
            ds_store["best_index_at_max_tau"] = 0
            ds_store["best_index_in_subset"] = best_idx

            # Drop the last references to the full-size per-τ tensors.
            del Y_obs, X_rec, partner

        del X_clean
        per_dataset_results[ds_key] = ds_store

    # Final single figure
    desired_order = ["mnist", "cifar10", "cifar100", "tiny-imagenet"]
    datasets_in_order = [ds for ds in desired_order if ds in per_dataset_results]

    if len(datasets_in_order) == 0 or len(taus) == 0:
        print("[warn] No datasets or no taus provided; skipping figure.")
        return

    save_multi_tau_grid(
        datasets_in_order=datasets_in_order,
        taus=taus,
        per_dataset_results=per_dataset_results,
        out_dir=out_dir,
        fname="multi_dataset_multi_tau.png"
    )

# ---------------------------
# CLI
# ---------------------------

def parse_args():
    """Parse the command line (``sys.argv``) and return the argparse namespace.

    Options cover dataset/τ selection, the mixup weight, subset size, device
    and seed, the statistics and attack-optimizer settings, the output
    directory, and an optional log file path.
    """
    parser = argparse.ArgumentParser(
        description="Simulate linear reconstruction attack on mixup (known mixing graph, TV+L2 priors)."
    )

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ("--datasets", dict(type=str, nargs="+",
                            default=["mnist", "cifar10", "cifar100", "tiny-imagenet"])),
        ("--taus", dict(type=float, nargs="+",
                        default=[1e-00, 1e-01, 1e-02, 1e-03, 1e-04, 1e-05, 1e-06])),
        ("--alpha", dict(type=float, default=0.7)),
        ("--sub_size", dict(type=int, default=512,
                            help="Subset size per dataset (opt scales with this).")),
        ("--device", dict(type=str, default="cpu")),
        ("--seed", dict(type=int, default=0)),
        # Stats
        ("--max_images_for_stats", dict(type=int, default=1024)),
        ("--batch_size_stats", dict(type=int, default=64)),
        # Attack optimizer
        ("--attack_steps", dict(type=int, default=200)),
        ("--attack_lr", dict(type=float, default=0.05)),
        ("--lambda_tv", dict(type=float, default=1e-3)),
        ("--lambda_l2", dict(type=float, default=1e-4)),
        ("--out_dir", dict(type=str, default="attack_figs")),
        # Logging
        ("--log_file", dict(type=str, default=None,
                            help="If set, write logs to this file (in addition to stdout). "
                                 "Default: attack_figs/run_YYYYmmdd_HHMMSS.log")),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()

# ---------------------------
# Entry
# ---------------------------

if __name__ == "__main__":
    # Only the script entry point needs this module, so it is imported here.
    import traceback

    args = parse_args()

    # Ensure output dir exists to place logs/figures
    os.makedirs(args.out_dir, exist_ok=True)

    # Choose log file path: a timestamped file in out_dir unless one was given.
    if args.log_file is None:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_path = os.path.join(args.out_dir, f"run_{stamp}.log")
    else:
        log_path = args.log_file
        # abspath() never returns an empty string, so dirname() is always a
        # usable directory here.
        os.makedirs(os.path.dirname(os.path.abspath(log_path)), exist_ok=True)

    _orig_stdout, _orig_stderr = sys.stdout, sys.stderr

    # Line-buffered log file; the with-block guarantees it is closed.
    with open(log_path, mode="w", buffering=1, encoding="utf-8") as _log_fh:
        # Mirror stdout + stderr into the log file for the duration of the run.
        sys.stdout = Tee(_orig_stdout, _log_fh)
        sys.stderr = Tee(_orig_stderr, _log_fh)

        try:
            print(f"[log] Logging to: {os.path.abspath(log_path)}")
            print(f"[env] Python: {sys.version.splitlines()[0]}")
            print(f"[env] Torch: {torch.__version__} | CUDA available: {torch.cuda.is_available()}")
            print(f"[args] {args}")

            # Any non-"cpu" device is only honoured when CUDA is available.
            use_requested = torch.cuda.is_available() or args.device == "cpu"
            if not use_requested:
                print(f"[warn] Requested device '{args.device}' but CUDA is not available; using cpu.")
            device = torch.device(args.device if use_requested else "cpu")

            run_experiment(
                datasets=args.datasets,
                taus=args.taus,
                alpha=args.alpha,
                sub_size=args.sub_size,
                device=device,
                seed=args.seed,
                max_images_for_stats=args.max_images_for_stats,
                batch_size_stats=args.batch_size_stats,
                attack_steps=args.attack_steps,
                attack_lr=args.attack_lr,
                lambda_tv=args.lambda_tv,
                lambda_l2=args.lambda_l2,
                out_dir=args.out_dir,
            )
            print(f"[done] Finished. Full log saved at: {os.path.abspath(log_path)}")
        except Exception:
            # The interpreter prints an uncaught traceback only after the
            # finally-block below has restored the original stderr, so it would
            # never reach the log. Record it in the file here (the console
            # still gets it once, from the interpreter) and re-raise.
            traceback.print_exc(file=_log_fh)
            raise
        finally:
            # Restore std streams; flushing is best-effort.
            try:
                sys.stdout.flush()
                sys.stderr.flush()
            except Exception:
                pass
            sys.stdout, sys.stderr = _orig_stdout, _orig_stderr
