import os, math, time, argparse, statistics, pathlib, random, csv
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt

# ----------------------------
# I/O & device
# ----------------------------
OUTDIR = "results_1"
FIGDIR = f"{OUTDIR}/figs"


def ensure_dir(p: str):
    """Create directory ``p`` together with any missing parents; do nothing if it exists."""
    target = pathlib.Path(p)
    target.mkdir(parents=True, exist_ok=True)


# Both output folders are created eagerly, at import time.
for _out_path in (OUTDIR, FIGDIR):
    ensure_dir(_out_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_seed(s=0):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices) with ``s``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(s)

# ----------------------------
# Data & tasks
# ----------------------------
def cheb_points(n: int, device=device):
    """Return the ``n`` Chebyshev nodes cos((2k + 1) * pi / (2n)) for k = 0, ..., n-1."""
    idx = torch.arange(n, device=device)
    angles = (2*idx + 1) * math.pi / (2*n)
    return torch.cos(angles)

def linspace_points(n: int, device=device):
    """Return ``n`` evenly spaced points on [-1, 1], both endpoints included."""
    left, right = -1.0, 1.0
    return torch.linspace(left, right, steps=n, device=device)

def f_runge(x):
    """Runge function 1 / (1 + 25 x^2)."""
    denominator = 1.0 + 25.0 * (x**2)
    return 1.0 / denominator


def f_abs(x):
    """Absolute value |x|."""
    return torch.abs(x)


def f_abs_alpha(x, alpha: float):
    """Power of the absolute value, |x|^alpha."""
    magnitude = torch.abs(x)
    return magnitude**alpha


# Task name -> target function. "abs_alpha" needs an extra exponent argument,
# so it has no ready-made callable here and is stored as None.
TASKS = {
    "runge": f_runge,
    "abs": f_abs,
    "abs_alpha": None,
}

def make_data(task: str, *, ntr=2048, nva=4096, device=device, alpha: float = 0.5):
    """
    Build the training and validation sets for ``task``.

    Training inputs are ``ntr`` Chebyshev nodes and validation inputs are ``nva``
    evenly spaced points; both are returned as single-column tensors.
    Returns ``((x_train, y_train), (x_val, y_val))``. ``alpha`` is only used by
    the "abs_alpha" task.
    """
    if task == "abs_alpha":
        def target(x):
            return f_abs_alpha(x, alpha)
    else:
        target = TASKS[task]

    x_train = cheb_points(ntr, device=device).unsqueeze(1)
    x_val = linspace_points(nva, device=device).unsqueeze(1)
    train_set = (x_train, target(x_train))
    val_set = (x_val, target(x_val))
    return train_set, val_set

# ----------------------------
# Models
# ----------------------------
def ensure_col(x: torch.Tensor) -> torch.Tensor:
    """Turn a 1-D tensor into a single-column 2-D tensor; return any other tensor unchanged."""
    if x.ndim == 1:
        return x.unsqueeze(1)
    return x

# class CauchyAct(nn.Module):
#     """
#     φ(z) = (l1*z + l2) / (z^2 + d^2) + γ z
#     l1,l2 = tanh(raw_*), d = softplus(raw_d)+eps.
#     """
#     def __init__(self, dim: int, d0: float = 0.3, eps: float = 1e-6,
#                  l1_init=0.8, l2_init=0.0, gamma0: float = 0.1):
#         super().__init__()
#         l1_init = float(np.clip(l1_init, -0.99, 0.99))
#         l2_init = float(np.clip(l2_init, -0.99, 0.99))
#         self.raw_l1 = nn.Parameter(torch.full((dim,), float(np.arctanh(l1_init))))
#         self.raw_l2 = nn.Parameter(torch.full((dim,), float(np.arctanh(l2_init))))
#         self.raw_d  = nn.Parameter(torch.full((dim,), math.log(math.expm1(d0))))
#         self.gamma  = nn.Parameter(torch.full((dim,), gamma0))
#         self.eps = eps

#     @property
#     def l1(self): return torch.tanh(self.raw_l1)
#     @property
#     def l2(self): return torch.tanh(self.raw_l2)
#     @property
#     def d (self): return F.softplus(self.raw_d) + self.eps

#     def forward(self, z):
#         # compute in float32 for stability (esp. under AMP), then cast back
#         z32 = z.float()
#         l1, l2, g, d = self.l1.float(), self.l2.float(), self.gamma.float(), self.d.float()
#         denom = z32*z32 + d*d
#         denom = torch.clamp(denom, min=1e-12)
#         y32 = (l1 * z32 + l2) / denom + g * z32
#         return y32.to(z.dtype)
class CauchyAct(nn.Module):
    """
    Per-channel Cauchy activation φ(z) = (l1*z + l2) / (z^2 + d^2).

    l1 and l2 are obtained by passing raw parameters through tanh; d is a
    learnable parameter stored directly, one value per channel.
    """
    def __init__(self, dim: int, d0: float = 0.3, eps: float = 1e-6,
                 l1_init: float = 0.8, l2_init: float = 0.0):
        super().__init__()

        def _pre_tanh(init: float) -> nn.Parameter:
            # Clip first so arctanh stays finite, then store the pre-tanh value.
            clipped = float(np.clip(init, -0.99, 0.99))
            return nn.Parameter(torch.full((dim,), float(np.arctanh(clipped))))

        self.raw_l1 = _pre_tanh(l1_init)
        self.raw_l2 = _pre_tanh(l2_init)
        self.d = nn.Parameter(torch.full((dim,), float(max(d0, 1e-6))))
        # eps enters the denominator squared, next to d^2.
        self.eps2 = float(eps * eps)

    @property
    def l1(self):
        return torch.tanh(self.raw_l1)

    @property
    def l2(self):
        return torch.tanh(self.raw_l2)

    def forward(self, z):
        # Evaluate in float32, then cast back to the dtype of the input.
        z_f = z.float()
        d_f = self.d.float()
        floor = (d_f * d_f) + self.eps2  # strictly positive
        numerator = self.l1.float() * z_f + self.l2.float()
        denominator = (z_f * z_f + floor).clamp(min=1e-12)
        return (numerator / denominator).to(z.dtype)

class XNetOneLayer(nn.Module):
    """Single hidden layer: Linear(in_dim -> width) -> CauchyAct -> Linear(width -> out_dim)."""
    def __init__(self, width=64, d0=0.5, in_dim=1, out_dim=1):
        super().__init__()
        self.width = width
        self.lin = nn.Linear(in_dim, width)
        self.act = CauchyAct(width, d0=d0)
        self.head = nn.Linear(width, out_dim)
        # Xavier-normal weights and zero biases for both affine layers.
        for layer in (self.lin, self.head):
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, collect_cauchy_z=False):
        """Return ``(output, z_list)``; ``z_list`` holds the pre-activation only when requested."""
        pre_act = self.lin(ensure_col(x))
        collected = [pre_act] if collect_cauchy_z else []
        return self.head(self.act(pre_act)), collected

class PAU(nn.Module):
    """Small Padé-style rational activation: (a1*x + a3*x^3) / (1 + softplus(b2)*x^2 + eps)."""
    def __init__(self, a1=1.0, a3=0.0, b2=0.5, eps=1e-6):
        super().__init__()
        # Three learnable scalars, shared by every element of the input.
        self.a1 = nn.Parameter(torch.tensor(float(a1)))
        self.a3 = nn.Parameter(torch.tensor(float(a3)))
        self.b2 = nn.Parameter(torch.tensor(float(b2)))
        self.eps = eps

    def forward(self, x):
        x_sq = x * x
        x_cu = x_sq * x
        numerator = self.a1 * x + self.a3 * x_cu
        # softplus keeps the quadratic coefficient positive.
        denominator = 1.0 + F.softplus(self.b2) * x_sq + self.eps
        return numerator / denominator

def get_activation(name: str):
    """
    Return a freshly constructed activation module for ``name`` (case-insensitive).

    Raises KeyError for a name that is not in the table.
    """
    factories = {
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "softplus": nn.Softplus,
        "elu": nn.ELU,
        "pau": PAU,
    }
    make = factories[name.lower()]
    return make()



class MLP(nn.Module):
    """Plain MLP: ``depth`` hidden Linear layers sharing one activation module, then a Linear head."""
    def __init__(self, width=64, depth=1, act_name="tanh", in_dim=1, out_dim=1):
        super().__init__()
        self.act = get_activation(act_name)
        # The first layer maps in_dim -> width; the remaining depth-1 layers are width -> width.
        fan_ins = [in_dim] + [width] * (depth - 1)
        self.lins = nn.ModuleList([nn.Linear(fan_in, width) for fan_in in fan_ins])
        self.head = nn.Linear(width, out_dim)
        # Xavier-normal weights, zero biases; hidden layers first, then the head.
        for layer in [*self.lins, self.head]:
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)
        self.width = width
        self.depth = depth
        self.act_name = act_name

    def forward(self, x, collect_cauchy_z=False):
        """Return ``(output, [])``; the empty list mirrors the Cauchy models' second return value."""
        hidden = ensure_col(x)
        for layer in self.lins:
            hidden = self.act(layer(hidden))
        return self.head(hidden), []

class MLPCauchyHead(nn.Module):
    """
    MLP trunk followed by a Cauchy block: (depth-1) x [Linear -> act] -> Linear -> CauchyAct -> Linear.

    With depth == 1 the trunk is empty and the network is Linear -> CauchyAct -> Linear.
    """
    def __init__(self, width=64, depth=2, act_name="tanh", d0=0.5, in_dim=1, out_dim=1):
        super().__init__()
        assert depth >= 1, "depth must be >=1 (depth=1 means no MLP trunk, only Cauchy head)"
        n_trunk = max(depth-1, 0)
        self.act_name = act_name
        self.act_trunk = get_activation(act_name)
        # Trunk: the first layer reads in_dim features, every later one reads width.
        trunk_fan_ins = ([in_dim] + [width] * (n_trunk - 1)) if n_trunk else []
        self.trunk = nn.ModuleList([nn.Linear(fan_in, width) for fan_in in trunk_fan_ins])
        # `pre` produces the Cauchy pre-activation; it reads the raw input when the trunk is empty.
        self.pre = nn.Linear(width if n_trunk else in_dim, width)
        self.cauchy = CauchyAct(width, d0=d0)
        self.head = nn.Linear(width, out_dim)
        # Xavier-normal weights, zero biases: trunk layers, then `pre`, then `head`.
        for layer in [*self.trunk, self.pre, self.head]:
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)
        self.width = width
        self.depth = depth

    def forward(self, x, collect_cauchy_z=False):
        """Return ``(output, z_list)``; ``z_list`` holds the Cauchy pre-activation only when requested."""
        hidden = ensure_col(x)
        for layer in self.trunk:
            hidden = self.act_trunk(layer(hidden))
        pre_act = self.pre(hidden)
        collected = [pre_act] if collect_cauchy_z else []
        return self.head(self.cauchy(pre_act)), collected

# ----------------------------
# Utils
# ----------------------------
def count_params(model: nn.Module) -> int:
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def solve_width_for_params(target_params:int, arch:str, depth:int, act_name:str, d0:float) -> int:
    """
    Binary-search the smallest width in [1, 4096] whose parameter count is >= ``target_params``.

    A candidate model is instantiated for every probe and its trainable parameters
    are counted. If no width in the range reaches the target, the search's final
    lower bound is returned. Raises ValueError for an unrecognised ``arch``.
    """
    def _candidate(width: int) -> nn.Module:
        if arch == "xnet":
            return XNetOneLayer(width=width, d0=d0)
        if arch == "mlp-cauchy-head":
            return MLPCauchyHead(width=width, depth=depth, act_name=act_name, d0=d0)
        if arch.startswith("mlp-"):
            return MLP(width=width, depth=depth, act_name=act_name)
        raise ValueError(f"Unknown arch for param solve: {arch}")

    low, high = 1, 4096
    smallest_ok = None
    while low <= high:
        probe = (low + high) // 2
        if count_params(_candidate(probe)) >= target_params:
            # Big enough: remember it and keep looking for something smaller.
            smallest_ok = probe
            high = probe - 1
        else:
            low = probe + 1
    return smallest_ok if smallest_ok is not None else low

def val_l2(model, xva, yva):
    """Return the mean-squared error of ``model`` on ``(xva, yva)`` as a Python float, without tracking gradients."""
    with torch.no_grad():
        prediction, _unused = model(xva, collect_cauchy_z=False)
        mse = F.mse_loss(prediction, yva)
        return mse.item()

def has_cauchy_component(model: nn.Module) -> bool:
    """Return True when ``model`` is one of the architectures that contain a CauchyAct."""
    cauchy_archs = (XNetOneLayer, MLPCauchyHead)
    return isinstance(model, cauchy_archs)

def slope_penalty_cauchy(model: nn.Module, slope_cap=0.4, weight=1e-2):
    """
    Penalise Cauchy channels whose slope magnitude at the origin exceeds ``slope_cap``.

    For each channel s = |l1| / (d^2 + eps^2), which is the magnitude of the
    activation's derivative at z = 0. The penalty is
    ``weight * mean(relu(s - slope_cap)^2)``.

    Returns a zero scalar tensor when ``weight <= 0`` or when ``model`` is not
    an XNetOneLayer / MLPCauchyHead.
    """
    if weight <= 0:
        return torch.tensor(0.0, device=device)
    if isinstance(model, XNetOneLayer):
        act = model.act
    elif isinstance(model, MLPCauchyHead):
        act = model.cauchy
    else:
        return torch.tensor(0.0, device=device)

    l1, d = act.l1, act.d
    # `d` is an unconstrained parameter and may reach zero during training. Use the
    # same d^2 + eps^2 floor as CauchyAct.forward so the ratio (and its gradient)
    # stays finite instead of becoming inf / NaN.
    d2 = d * d + getattr(act, "eps2", 0.0)
    s = torch.abs(l1) / d2
    return weight * torch.mean(F.relu(s - slope_cap) ** 2)

def fit_rate(xs, errs, root=False):
    """
    Least-squares fit of log(err) ~ a * x + b, with x replaced by sqrt(x) when ``root`` is set.

    Returns ``(a, b, r2)`` as Python floats, where ``r2`` is the coefficient of
    determination of the fit in log space.
    """
    x = np.array(xs, float)
    if root:
        x = np.sqrt(x)
    # The tiny offset keeps log() finite for an exactly-zero error.
    log_err = np.log(np.array(errs, float) + 1e-30)

    design = np.column_stack([x, np.ones_like(x)])
    solution = np.linalg.lstsq(design, log_err, rcond=None)[0]
    slope, intercept = solution

    residual = log_err - (slope * x + intercept)
    ss_res = np.sum(residual**2)
    ss_tot = np.sum((log_err - log_err.mean())**2)
    r2 = 1 - ss_res / (ss_tot + 1e-12)
    return float(slope), float(intercept), float(r2)

# ----------------------------
# Train / Run
# ----------------------------
def parse_arch(arch: str):
    """
    Split an architecture string into ``(kind, act_name)``.

    ``kind`` is one of 'xnet', 'mlp', 'mlp-cauchy-head'. ``act_name`` is None for
    'xnet', 'tanh' for 'mlp-cauchy-head', and the text after the first '-' for
    'mlp-<act>'. Matching is case-insensitive. Raises ValueError otherwise.
    """
    name = arch.lower()
    if name in ("xnet", "mlp-cauchy-head"):
        # The Cauchy-head trunk defaults to tanh; callers may override it elsewhere.
        return name, (None if name == "xnet" else "tanh")
    if name.startswith("mlp-"):
        _prefix, _sep, activation = name.partition("-")
        return "mlp", activation
    raise ValueError(f"Unknown arch string: {name}")

def build_model(arch: str, width: int, depth: int, d0: float, act_name="tanh"):
    """
    Instantiate the model described by ``arch`` and move it to the module-level ``device``.

    ``depth`` is ignored for 'xnet'. For 'mlp-<act>' the activation comes from the
    arch string. For 'mlp-cauchy-head' the trunk activation is ``act_name`` when
    truthy, otherwise the default reported by ``parse_arch``.
    """
    kind, parsed_act = parse_arch(arch)
    if kind == "xnet":
        net = XNetOneLayer(width=width, d0=d0)
    elif kind == "mlp":
        net = MLP(width=width, depth=depth, act_name=parsed_act)
    elif kind == "mlp-cauchy-head":
        trunk_act = act_name or parsed_act or "tanh"
        net = MLPCauchyHead(width=width, depth=depth, act_name=trunk_act, d0=d0)
    else:
        raise ValueError(f"Unknown arch kind: {arch}")
    return net.to(device)

def default_target_for_task(task: str) -> float:
    """Return the default validation-L2 threshold for ``task`` (1e-4 for any unlisted task)."""
    known_targets = (
        ("runge", 1e-6),
        ("abs_alpha", 1e-4),
        ("abs", 1e-3),
    )
    for name, threshold in known_targets:
        if task == name:
            return threshold
    return 1e-4

def train_one(task, arch, width, depth, d0, alpha, iters, lr, wd, eval_every, time_budget, amp, seed,
              slope_cap=0.4, slope_weight=1e-2, target_l2=None,
              anneal_d=None, anneal_start_frac=None, anneal_factor=None):
    """
    Train one model on one task and report its best validation error.

    The model is trained full-batch with AdamW on MSE plus the Cauchy slope
    penalty, with gradient-norm clipping at 1.0. Validation L2 is evaluated at
    iteration 1, every ``eval_every`` iterations, on the last iteration, and
    when ``time_budget`` (seconds, or None for no limit) is exceeded, at which
    point training stops. The parameters with the lowest validation L2 seen are
    restored before returning.

    Annealing options (``anneal_d``, ``anneal_start_frac``, ``anneal_factor``):
    when left as None they fall back to the attributes of the same name on a
    module-level ``args`` namespace if one exists (script usage), otherwise to
    False / 0.8 / 0.995. When enabled, the Cauchy width ``d`` is multiplied by
    ``anneal_factor`` after every optimiser step from iteration
    ``int(anneal_start_frac * iters)`` onwards.

    Returns:
        (best_val_l2, elapsed_seconds, trainable_param_count, t_cross) where
        ``t_cross`` is the time in seconds at which validation L2 was first
        observed at or below ``target_l2``, or None if that never happened or
        no target was given.

    Raises:
        ValueError: if annealing is active for a Cauchy model and
            ``anneal_factor`` is not positive.
    """
    # Resolve the annealing options without requiring a global `args` to exist,
    # so the function also works when this file is imported as a module.
    cli = globals().get("args")
    if anneal_d is None:
        anneal_d = bool(getattr(cli, "anneal_d", False))
    if anneal_start_frac is None:
        anneal_start_frac = getattr(cli, "anneal_start_frac", 0.8)
    if anneal_factor is None:
        anneal_factor = getattr(cli, "anneal_factor", 0.995)

    set_seed(seed)
    (xtr, ytr), (xva, yva) = make_data(task, device=device, alpha=alpha)
    # Determine act_name for the trunk if needed
    kind, act_name = parse_arch(arch)
    model = build_model(arch, width, depth, d0, act_name=act_name if kind != "xnet" else None)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scaler = torch.amp.GradScaler('cuda', enabled=amp)

    # Locate the Cauchy activation to anneal, if annealing applies to this model.
    cauchy_act = None
    if anneal_d:
        if isinstance(model, XNetOneLayer):
            cauchy_act = model.act
        elif isinstance(model, MLPCauchyHead):
            cauchy_act = model.cauchy
    if cauchy_act is not None and anneal_factor <= 0:
        raise ValueError(f"anneal_factor must be > 0, got {anneal_factor}")

    best_l2, best_state = float("inf"), None
    t_cross = None
    t0 = time.perf_counter()
    anneal_start = int(anneal_start_frac * iters)

    for it in range(1, iters+1):
        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=amp):
            # The collected pre-activations are not used during training.
            pred, _ = model(xtr, collect_cauchy_z=False)
            loss = F.mse_loss(pred, ytr)
            loss = loss + slope_penalty_cauchy(model, slope_cap, slope_weight)
        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale them so the
        # clipping threshold applies to the true gradient norm (no-op when AMP is off).
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt); scaler.update()

        # Optional shrink of the Cauchy width late in training
        if cauchy_act is not None and it >= anneal_start:
            with torch.no_grad():
                raw_d = getattr(cauchy_act, "raw_d", None)
                if raw_d is not None:
                    # Parameterisation that stores a pre-transform `raw_d`.
                    raw_d.add_(math.log(anneal_factor))
                else:
                    # Current CauchyAct stores `d` directly as a parameter.
                    cauchy_act.d.mul_(anneal_factor)

        need_eval = (it == 1) or (it % eval_every == 0)
        overtime = (time_budget is not None) and (time.perf_counter()-t0 >= time_budget)
        if need_eval or overtime or it == iters:
            l2 = val_l2(model, xva, yva)
            if target_l2 is not None and t_cross is None and l2 <= target_l2:
                t_cross = time.perf_counter() - t0
            if l2 < best_l2:
                best_l2 = l2
                # Snapshot on CPU so the checkpoint is independent of later updates.
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            print(f"{task:<9} | {arch:<16} | w={width:4d} | d={depth:2d} | "
                f"seed={seed} | it={it:4d} | L2={l2:.2e} | best={best_l2:.2e}")
        if overtime: break

    if best_state is not None: model.load_state_dict(best_state)
    elapsed = time.perf_counter()-t0
    return val_l2(model, xva, yva), elapsed, count_params(model), t_cross

def run_suite(args):
    """
    Run the full benchmark described by ``args`` and write CSVs and figures.

    For every anchor width in ``args.widths`` and every architecture in
    ``args.archs``, each (depth, width) combination is trained once per seed:
      * 'xnet' always runs at depth 1 with the anchor width;
      * other archs run at every depth in ``args.mlp_depths`` (or the single
        ``args.mlp_depth``); their width is either solved so the parameter
        count matches the XNet anchor (``--param-matched``) or equal to the
        anchor width.

    Outputs (under OUTDIR / FIGDIR):
      * ``<task>_all.csv``: one row per run;
      * ``<task>_threshold_table.csv``: per arch, min params / min time among runs reaching the target L2;
      * ``<task>_overlay.png``: median error vs width, best-of-depth per arch;
      * ``<task>_rate_{M,sqrtM}.png``: rate fits, only with ``--rate-plots``;
      * ``<task>_err_vs_params.png``: median error vs parameter count.
    """
    torch.set_float32_matmul_precision("high")
    print(f"Device={device}, AMP={bool(args.amp)}")

    results = []  # dict(task, arch, depth, seed, width, params, L2, time, anchor_w, anchor_params, time_to_thresh)
    # `is not None` rather than `or`, so an explicit threshold of 0.0 is honoured.
    target_l2 = args.target_l2 if args.target_l2 is not None else default_target_for_task(args.task)
    print(f"Target L2 for thresholding: {target_l2}")

    for w_anchor in args.widths:
        # Parameter count of the XNet at this width: the matching anchor.
        px = count_params(build_model("xnet", w_anchor, depth=1, d0=args.d0))

        for arch in args.archs:
            _kind, act_name = parse_arch(arch)
            if arch == "xnet":
                combos = [(1, w_anchor)]
            else:
                depths_this = args.mlp_depths if args.mlp_depths else [args.mlp_depth]
                if args.param_matched:
                    widths_this = [solve_width_for_params(px, arch, depth=d, act_name=act_name or "tanh", d0=args.d0)
                                   for d in depths_this]
                else:
                    # Same width for every depth: the current anchor width. Pairing
                    # depths with the whole width list would tie depth i to width i
                    # and repeat identical runs for every anchor.
                    widths_this = [w_anchor] * len(depths_this)
                combos = list(zip(depths_this, widths_this))

            for depth, ww in combos:
                for s in args.seeds:
                    l2, dt, p, t_cross = train_one(
                        task=args.task, arch=arch, width=ww, depth=depth,
                        d0=args.d0, alpha=args.alpha, iters=args.iters, lr=args.lr, wd=args.wd,
                        eval_every=args.eval_every, time_budget=args.time_budget, amp=bool(args.amp),
                        seed=s, slope_cap=args.slope_cap, slope_weight=args.slope_weight, target_l2=target_l2
                    )
                    results.append({"task": args.task, "arch": arch, "depth": depth, "seed": s,
                                    "width": ww, "params": p, "L2": l2, "time": dt,
                                    "anchor_w": w_anchor, "anchor_params": px, "time_to_thresh": t_cross})

    # ---------- raw results ----------
    csv_path = f"{OUTDIR}/{args.task}_all.csv"
    with open(csv_path, "w", newline="") as f:
        wtr = csv.DictWriter(f, fieldnames=["task","arch","depth","seed","width","params","L2","time","anchor_w","anchor_params","time_to_thresh"])
        wtr.writeheader(); wtr.writerows(results)
    print(f"Saved CSV -> {csv_path}")

    # ---------- threshold table (params-to-threshold & time-to-threshold) ----------
    # Per arch: the smallest model and the earliest crossing among runs that reached the target.
    rows = []
    by_arch = {}
    for r in results:
        by_arch.setdefault(r["arch"], []).append(r)
    for arch, lst in by_arch.items():
        # A run counts if its final L2 is at/below the target or it crossed the target at some evaluation.
        reach = [r for r in lst if r["L2"] <= target_l2 or (r["time_to_thresh"] is not None)]
        if reach:
            min_params = min(r["params"] for r in reach)
            # For time-to-threshold, only runs with a recorded crossing time qualify.
            t_candidates = [r["time_to_thresh"] for r in reach if r["time_to_thresh"] is not None]
            min_time = min(t_candidates) if t_candidates else None
        else:
            min_params, min_time = None, None
        rows.append({"arch": arch, "min_params@thr": min_params, "min_time@thr": min_time})
    thr_csv = f"{OUTDIR}/{args.task}_threshold_table.csv"
    with open(thr_csv, "w", newline="") as f:
        wtr = csv.DictWriter(f, fieldnames=["arch","min_params@thr","min_time@thr"])
        wtr.writeheader(); wtr.writerows(rows)
    print(f"Saved threshold table -> {thr_csv}")

    # ---------- figures ----------
    # A) error vs width (x axis is the XNet anchor width when param-matched); best-of-depth per arch
    if args.param_matched:
        x_key = "anchor_w"
        x_label = "XNet width anchor"
        title_suffix = "param-matched (by XNet width anchor)"
    else:
        x_key = "width"
        x_label = "width"
        title_suffix = "same widths"

    # arch -> depth -> {x: [errs over seeds]}
    med_arch_depth = {}
    for r in results:
        med_arch_depth.setdefault(r["arch"], {}).setdefault(r["depth"], {}).setdefault(r[x_key], []).append(r["L2"])

    xs_sorted = sorted({r[x_key] for r in results})

    # Best-of-depth envelope: for each x, the lowest per-depth median over seeds.
    best_of_depth = {}  # arch -> [value per x in xs_sorted]
    for arch, depth_map in med_arch_depth.items():
        depth_curves = {}
        for d, xmap in depth_map.items():
            # A depth with no run at this x contributes +inf so it never wins the min.
            depth_curves[d] = [statistics.median(xmap[x]) if x in xmap else float("inf") for x in xs_sorted]
        best_of_depth[arch] = [min(curve[i] for curve in depth_curves.values())
                               for i in range(len(xs_sorted))]

    # Overlay: XNet (depth 1) together with each other arch's best-of-depth curve.
    plt.figure(figsize=(6.6,4.2))
    for arch in args.archs:
        meds = best_of_depth[arch]
        label = arch if arch=="xnet" else f"{arch} (best-of-depth)"
        plt.semilogy(xs_sorted, meds, marker="o", label=label)
    plt.xlabel(x_label); plt.ylabel("median L2 (val)")
    plt.title(f"{args.task} | {title_suffix}")
    plt.grid(True, which="both", alpha=0.3); plt.legend()
    fig_path = f"{FIGDIR}/{args.task}_overlay.png"
    plt.tight_layout(); plt.savefig(fig_path, dpi=220); plt.close()
    print(f"Saved FIG -> {fig_path}")

    # B) rate plots on the best-of-depth curves: log(err) fitted against x and against sqrt(x)
    if args.rate_plots:
        for mode in ["M","sqrtM"]:
            plt.figure(figsize=(6.6,4.2))
            root = (mode=="sqrtM")
            for arch in args.archs:
                meds = best_of_depth[arch]
                a, b, r2 = fit_rate(xs_sorted, meds, root=root)
                xs_plot = np.sqrt(xs_sorted) if root else xs_sorted
                plt.semilogy(xs_plot, meds, marker="o", label=f"{arch} (a={a:.3f}, R2={r2:.3f})")
            plt.xlabel(("sqrt("+x_label+")") if root else x_label)
            plt.ylabel("val L2 (log scale)")
            plt.title(f"{args.task} | rate id: {mode} | {title_suffix}")
            plt.grid(True, which="both", alpha=0.3); plt.legend(fontsize=9)
            out = f"{FIGDIR}/{args.task}_rate_{mode}.png"
            plt.tight_layout(); plt.savefig(out, dpi=220); plt.close()
            print(f"Saved FIG -> {out}")

    # C) error vs number of parameters (best-of-depth, median over seeds)
    # arch -> depth -> params -> [errs]
    by_arch_depth_params = {}
    for r in results:
        by_arch_depth_params.setdefault(r["arch"], {}).setdefault(r["depth"], {}).setdefault(r["params"], []).append(r["L2"])

    best_by_params = {}
    for arch, depth_map in by_arch_depth_params.items():
        bucket = {}
        all_params = sorted({p for pmap in depth_map.values() for p in pmap})
        for p in all_params:
            cand = [statistics.median(pmap[p]) for pmap in depth_map.values() if p in pmap]
            if cand:
                bucket[p] = min(cand)
        best_by_params[arch] = bucket

    plt.figure(figsize=(6.6,4.2))
    for arch in args.archs:
        xs = sorted(best_by_params[arch].keys())
        ys = [best_by_params[arch][p] for p in xs]
        label = arch if arch=="xnet" else f"{arch} (best-of-depth)"
        plt.semilogy(xs, ys, marker="o", label=label)
    plt.xlabel("#params (trainable)")
    plt.ylabel("median L2 (val)")
    plt.title(f"{args.task} | error vs params (best-of-depth)")
    plt.grid(True, which="both", alpha=0.3); plt.legend()
    out2 = f"{FIGDIR}/{args.task}_err_vs_params.png"
    plt.tight_layout(); plt.savefig(out2, dpi=220); plt.close()
    print(f"Saved FIG -> {out2}")

# ----------------------------
# CLI
# ----------------------------
def build_argparser():
    """Build the command-line parser for the benchmark suite."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # experiment grid
    add("--task", type=str, default="abs", choices=list(TASKS.keys()))
    add(
        "--archs", type=str, nargs="+",
        default=["xnet","mlp-tanh","mlp-relu","mlp-gelu","mlp-pau","mlp-cauchy-head"],
        help="xnet=one-layer Cauchy; mlp-* baselines; mlp-cauchy-head=minimal-change baseline",
    )
    add("--widths", type=int, nargs="+", default=[16,32,64,128])
    add("--mlp-depth", type=int, default=1,
        help="single MLP depth if --mlp-depths not provided")
    add("--mlp-depths", type=int, nargs="+", default=[1,3,6],
        help="candidate MLP depths; best-of-depth envelope will be plotted")
    add("--param-matched", action="store_true",
        help="match each non-XNet arch (for each depth) to XNet params at each anchor width")
    add("--seeds", type=int, nargs="+", default=[0,1])

    # training
    add("--iters", type=int, default=4000)
    add("--lr", type=float, default=2e-3)
    add("--wd", type=float, default=1e-6)
    add("--eval-every", type=int, default=400)
    add("--time-budget", type=float, default=18.0)

    # XNet / Cauchy-head knobs
    add("--d0", type=float, default=0.5)
    add("--slope-cap", type=float, default=0.4)
    add("--slope-weight", type=float, default=1e-2)

    # tasks and reporting
    add("--alpha", type=float, default=0.5, help="only for abs_alpha")
    add("--rate-plots", action="store_true")
    add("--target-l2", type=float, default=None,
        help="threshold L2 for time-to-accuracy table; default depends on task")

    # late-training annealing of the Cauchy width
    add("--anneal-d", action="store_true",
        help="shrink Cauchy d late in training (off by default)")
    add("--anneal-start-frac", type=float, default=0.8)
    add("--anneal-factor", type=float, default=0.995)

    # engineering
    add("--amp", action="store_true")
    return parser

if __name__ == "__main__":
    # Script entry point: parse the CLI and run the whole suite.
    # NOTE: `args` is deliberately a module-level name; train_one reads it as a global.
    args = build_argparser().parse_args()
    run_suite(args)
