{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56e99126",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np_np_gauss_or_climsim.py\n",
    "# ------------------------------------------------------------\n",
    "# Neyman–Pearson (Target + Source) with decision rule\n",
    "# Two data modes:\n",
    "#   1) Multidimensional Gaussians (stable, no standardization)\n",
    "#   2) ClimSim loader (optional; requires climsim_data_import)\n",
    "#\n",
    "# One architecture switch used EVERYWHERE:\n",
    "#   MODEL_KIND in main(): \"linear\"  or  \"mlp\" (ReLU, 16,16)\n",
    "#\n",
    "# Baselines (same routine; only data/ε differ):\n",
    "#   (A) θ*_T (target-only)\n",
    "#   (B) θ*_S (source-only)\n",
    "#   (C) θ*_Agg (T+S aggregated)\n",
    "#\n",
    "# NEW: Two sweep utilities (no algorithm changes):\n",
    "#   • sweep_target_samples_vs_errors(...) : vary n_{0,T}=n_{1,T}, plot Type-I/II (4 curves)\n",
    "#   • sweep_source_samples_vs_errors(...) : vary n_{0,S}=n_{1,S}, plot Type-I/II (4 curves)\n",
    "# Each point is averaged over N trials (default 10).\n",
    "# ------------------------------------------------------------\n",
    "\n",
    "import os\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import matplotlib.pyplot as plt  # <— for plots\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# ============== Utils ==============\n",
    "\n",
    "def eps_from_n(n, c=1.0):\n",
    "    return float(c) / math.sqrt(float(max(1, n)))\n",
    "\n",
    "def make_gaussian(n, d, mean, var, generator=None):\n",
    "    g = generator\n",
    "    x = torch.randn(n, d, generator=g) if g is not None else torch.randn(n, d)\n",
    "    return x * math.sqrt(var) + mean * torch.ones(1, d)\n",
    "\n",
    "def make_pair_sep(n0, n1, d, t0, c0, t1, c1, seed=None):\n",
    "    g = torch.Generator().manual_seed(int(seed)) if seed is not None else None\n",
    "    x0 = make_gaussian(n0, d, t0, c0, g)\n",
    "    x1 = make_gaussian(n1, d, t1, c1, g)\n",
    "    return x0, x1\n",
    "\n",
    "def phi_pos(z, k):  # for R0\n",
    "    return torch.sigmoid(k * z)\n",
    "\n",
    "def phi_neg(z, k):  # for R1\n",
    "    return torch.sigmoid(-k * z)\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_target_test(model, k, x0t, x1t, title=None):\n",
    "    h0 = model(x0t.float()); h1 = model(x1t.float())\n",
    "    R0T_sur = phi_pos(h0, k).mean().item()\n",
    "    R1T_sur = phi_neg(h1, k).mean().item()\n",
    "    typeI   = (h0 > 0).float().mean().item()\n",
    "    typeII  = (h1 <= 0).float().mean().item()\n",
    "    return {\"R0T_sur\": R0T_sur, \"R1T_sur\": R1T_sur, \"typeI\": typeI, \"typeII\": typeII}\n",
    "\n",
    "def _assert_trainable(model):\n",
    "    if not any(p.requires_grad for p in model.parameters()):\n",
    "        raise RuntimeError(\"Model has no trainable parameters (requires_grad=False).\")\n",
    "\n",
    "# ============== Models & Factories ==============\n",
    "\n",
    "class MLP2(nn.Module):\n",
    "    def __init__(self, d, h1=16, h2=16):\n",
    "        super().__init__()\n",
    "        self.l1  = nn.Linear(d, h1, bias=True)\n",
    "        self.l2  = nn.Linear(h1, h2, bias=True)\n",
    "        self.out = nn.Linear(h2, 1,  bias=True)\n",
    "        self.act = nn.ReLU()\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        nn.init.kaiming_normal_(self.l1.weight, nonlinearity='relu'); nn.init.zeros_(self.l1.bias)\n",
    "        nn.init.kaiming_normal_(self.l2.weight, nonlinearity='relu'); nn.init.zeros_(self.l2.bias)\n",
    "        nn.init.xavier_uniform_(self.out.weight, gain=0.1);          nn.init.zeros_(self.out.bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.act(self.l1(x))\n",
    "        x = self.act(self.l2(x))\n",
    "        return self.out(x)\n",
    "\n",
    "def make_linear_factory(d):\n",
    "    def factory():\n",
    "        m = nn.Linear(d, 1, bias=True)\n",
    "        with torch.no_grad():\n",
    "            nn.init.xavier_uniform_(m.weight, gain=0.1)\n",
    "            nn.init.zeros_(m.bias)\n",
    "        return m\n",
    "    return factory\n",
    "\n",
    "def make_mlp_factory(d, hidden=(16,16)):\n",
    "    h1, h2 = hidden\n",
    "    def factory():\n",
    "        return MLP2(d, h1=h1, h2=h2)\n",
    "    return factory\n",
    "\n",
    "def make_model_factory_by_kind(kind, d):\n",
    "    kind = (kind or \"linear\").lower()\n",
    "    if kind == \"mlp\":\n",
    "        return make_mlp_factory(d, hidden=(16,16))\n",
    "    return make_linear_factory(d)\n",
    "\n",
    "def zero_init_last_layer(model):\n",
    "    with torch.no_grad():\n",
    "        if isinstance(model, nn.Linear):\n",
    "            model.weight.zero_(); model.bias.zero_()\n",
    "        elif hasattr(model, \"out\") and isinstance(model.out, nn.Linear):\n",
    "            model.out.weight.zero_(); model.out.bias.zero_()\n",
    "        else:\n",
    "            raise RuntimeError(\"Cannot locate final linear layer to zero-init.\")\n",
    "\n",
    "def set_train_last_layer_only(model, last_only=True):\n",
    "    if isinstance(model, nn.Linear):\n",
    "        for p in model.parameters(): p.requires_grad_(True)\n",
    "        return\n",
    "    if not hasattr(model, \"out\") or not isinstance(model.out, nn.Linear):\n",
    "        for p in model.parameters(): p.requires_grad_(True)\n",
    "        return\n",
    "    for name, p in model.named_parameters():\n",
    "        if last_only and not name.startswith(\"out.\"):\n",
    "            p.requires_grad_(False)\n",
    "        else:\n",
    "            p.requires_grad_(True)\n",
    "\n",
    "class ParamEMA:\n",
    "    def __init__(self, model):\n",
    "        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}\n",
    "        self.t = 0\n",
    "    @torch.no_grad()\n",
    "    def update(self, model):\n",
    "        self.t += 1\n",
    "        for k, v in model.state_dict().items():\n",
    "            self.shadow[k] += (v.detach() - self.shadow[k]) / float(self.t)\n",
    "    @torch.no_grad()\n",
    "    def copy_to(self, model):\n",
    "        model.load_state_dict(self.shadow, strict=True)\n",
    "\n",
    "# ============== θ* on TARGET (tightened α − ε0T) ==============\n",
    "\n",
    "def pretrain_unconstrained(model, x0, x1, steps=200, lr=1e-2, weight_decay=1e-5):\n",
    "    model.train(); _assert_trainable(model)\n",
    "    x0 = x0.float(); x1 = x1.float()\n",
    "    y0 = -torch.ones(x0.size(0), 1, dtype=x0.dtype); y1 = torch.ones(x1.size(0), 1, dtype=x1.dtype)\n",
    "    X = torch.cat([x0, x1], dim=0); Y = torch.cat([y0, y1], dim=0)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    for _ in range(steps):\n",
    "        opt.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            h = model(X)\n",
    "            loss = torch.log1p(torch.exp(-Y * h)).mean()\n",
    "            if not loss.requires_grad:\n",
    "                raise RuntimeError(\"Pretrain loss lost grad path (check global no_grad).\")\n",
    "            loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "        opt.step()\n",
    "\n",
    "def train_np_sigmoid(alpha, k, T, eta_theta, eta_lambda, delta,\n",
    "                     x0, x1, model_factory, lam_cap=5.0,\n",
    "                     pretrain_steps=200, wd=1e-5, rho=1.0):\n",
    "    \"\"\"\n",
    "    Augmented Lagrangian for θ*:\n",
    "      L_aug = R1_hat + lam * g + rho * relu(g)^2\n",
    "      g = R0_hat - alpha - delta\n",
    "    Uses EMA for stability and returns the EMA model.\n",
    "    \"\"\"\n",
    "    model = model_factory()\n",
    "    for p in model.parameters(): p.requires_grad_(True)\n",
    "    model.train(); _assert_trainable(model)\n",
    "    x0 = x0.float(); x1 = x1.float()\n",
    "\n",
    "    pretrain_unconstrained(model, x0, x1, steps=pretrain_steps, lr=1e-2, weight_decay=wd)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=eta_theta, weight_decay=wd)\n",
    "    lam = 0.0\n",
    "    ema = ParamEMA(model)\n",
    "\n",
    "    for _ in range(T):\n",
    "        opt.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            h0 = model(x0); h1 = model(x1)\n",
    "            R0_hat = phi_pos(h0, k).mean()\n",
    "            R1_hat = phi_neg(h1, k).mean()\n",
    "            g_val  = R0_hat - alpha - delta\n",
    "            penalty = torch.relu(g_val)**2\n",
    "            L = R1_hat + lam * g_val + rho * penalty\n",
    "            if not L.requires_grad:\n",
    "                raise RuntimeError(\"NP loss lost grad path.\")\n",
    "            L.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "        opt.step()\n",
    "        with torch.no_grad():\n",
    "            ema.update(model)\n",
    "            lam = max(lam + eta_lambda * g_val.item(), 0.0)\n",
    "            if lam_cap is not None: lam = min(lam, float(lam_cap))\n",
    "\n",
    "    model_avg = model_factory()\n",
    "    ema.copy_to(model_avg)\n",
    "    with torch.no_grad():\n",
    "        h0 = model_avg(x0); h1 = model_avg(x1)\n",
    "        R0_tr = phi_pos(h0, k).mean().item()\n",
    "        R1_tr = phi_neg(h1, k).mean().item()\n",
    "        typeI_tr = (h0 > 0).float().mean().item()\n",
    "        typeII_tr = (h1 <= 0).float().mean().item()\n",
    "    return model_avg, R0_tr, R1_tr, typeI_tr, typeII_tr\n",
    "\n",
    "def solve_theta_star_T_tight(x0T, x1T, alpha, eps0T, k, T_inner,\n",
    "                             eta_theta, eta_lambda, delta, model_factory, rho=1.0):\n",
    "    alpha_tight = max(alpha - eps0T, 0.0)\n",
    "    model, _, R1_star, *_ = train_np_sigmoid(\n",
    "        alpha=alpha_tight, k=k, T=T_inner,\n",
    "        eta_theta=eta_theta, eta_lambda=eta_lambda, delta=delta,\n",
    "        x0=x0T, x1=x1T, model_factory=model_factory, lam_cap=5.0,\n",
    "        rho=rho\n",
    "    )\n",
    "    return model, float(R1_star)\n",
    "\n",
    "# ============== α′ optimization (same model kind with safe α′) ==============\n",
    "\n",
    "def _alpha_from_u(u, alpha, tau=1e-6):\n",
    "    return float(alpha) + (1.0 - float(alpha) - tau) * torch.sigmoid(torch.tensor(u)).item()\n",
    "\n",
    "def optimize_alpha_prime(\n",
    "    alpha,\n",
    "    x0T, x1T, x0S,\n",
    "    model_kind,\n",
    "    k=20.0,\n",
    "    c_eps0T=0.5, c_eps0S=0.5, c_eps1T=1.0,\n",
    "    T_outer=4000,\n",
    "    eta_theta=0.02,   # small θ-step in α′ loop\n",
    "    eta_alpha=0.05,   # step on u (the α′ logits)\n",
    "    eta_lambda1=1.0, eta_lambda2=2.0, eta_lambda3=1.0,\n",
    "    delta=0.0, T_inner_np=3000,\n",
    "    alpha_loop_train_last_layer_only=True,\n",
    "    alpha_loop_weight_decay=1e-4\n",
    "):\n",
    "    n0T, d = x0T.shape; n1T, _ = x1T.shape; n0S, _ = x0S.shape\n",
    "    eps0T = eps_from_n(n0T, c_eps0T)\n",
    "    eps0S = eps_from_n(n0S, c_eps0S)\n",
    "    eps1T = eps_from_n(n1T, c_eps1T)\n",
    "\n",
    "    model_factory = make_model_factory_by_kind(model_kind, d)\n",
    "    rho_theta = 2.0 if model_kind.lower() == \"mlp\" else 1.0\n",
    "    eta_theta_theta = 0.02 if model_kind.lower() == \"mlp\" else 0.1\n",
    "\n",
    "    theta_star_model, R1_star = solve_theta_star_T_tight(\n",
    "        x0T, x1T, alpha, eps0T, k, T_inner_np,\n",
    "        eta_theta=eta_theta_theta, eta_lambda=eta_lambda1, delta=delta,\n",
    "        model_factory=model_factory, rho=rho_theta\n",
    "    )\n",
    "\n",
    "    model_alpha = model_factory()\n",
    "    zero_init_last_layer(model_alpha)\n",
    "    set_train_last_layer_only(model_alpha, last_only=alpha_loop_train_last_layer_only)\n",
    "    model_alpha.train(); _assert_trainable(model_alpha)\n",
    "\n",
    "    opt_theta = torch.optim.SGD(filter(lambda p: p.requires_grad, model_alpha.parameters()),\n",
    "                                lr=eta_theta, weight_decay=alpha_loop_weight_decay)\n",
    "\n",
    "    def R0_T_hat(): return phi_pos(model_alpha(x0T.float()), k).mean()\n",
    "    def R1_T_hat(): return phi_neg(model_alpha(x1T.float()), k).mean()\n",
    "    def R0_S_hat(): return phi_pos(model_alpha(x0S.float()), k).mean()\n",
    "\n",
    "    u = 0.0\n",
    "    alpha_prime = _alpha_from_u(u, alpha)\n",
    "\n",
    "    lam1 = lam2 = lam3 = 0.0\n",
    "    alpha_prime_avg = 0.0\n",
    "    alpha_prime_last = alpha_prime\n",
    "\n",
    "    with torch.no_grad():\n",
    "        g2_0 = (R0_S_hat() - alpha_prime - eps0S).item()\n",
    "        if g2_0 > 0:\n",
    "            lam2 = min(1.5, 5.0)\n",
    "        lam1 = 0.0\n",
    "        lam3 = 0.0\n",
    "\n",
    "    for t in range(1, T_outer + 1):\n",
    "        opt_theta.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            g1 = R0_T_hat() - alpha - eps0T - delta\n",
    "            g2 = R0_S_hat() - alpha_prime - eps0S\n",
    "            g3 = R1_T_hat() - R1_star - eps1T\n",
    "            L  = alpha_prime + lam1*g1 + lam2*g2 + lam3*g3\n",
    "            if not L.requires_grad:\n",
    "                raise RuntimeError(\"Alpha-prime objective lost grad path.\")\n",
    "            L.backward()\n",
    "        opt_theta.step()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            lam1 = min(max(lam1 + eta_lambda1 * g1.item(), 0.0), 5.0)\n",
    "            lam2 = min(max(lam2 + eta_lambda2 * g2.item(), 0.0), 5.0)\n",
    "            lam3 = min(max(lam3 + eta_lambda3 * g3.item(), 0.0), 5.0)\n",
    "\n",
    "            u += eta_alpha * (lam2 - 1.0)\n",
    "            alpha_prime = _alpha_from_u(u, alpha)\n",
    "\n",
    "            alpha_prime_last = alpha_prime\n",
    "            alpha_prime_avg += (alpha_prime - alpha_prime_avg) / t\n",
    "\n",
    "    alpha_prime_avg = max(alpha_prime_avg, float(alpha))\n",
    "    alpha_prime_last = max(alpha_prime_last, float(alpha))\n",
    "\n",
    "    return {\n",
    "        \"alpha_prime\": alpha_prime_avg,\n",
    "        \"alpha_prime_last\": alpha_prime_last,\n",
    "        \"eps0T\": eps0T, \"eps0S\": eps0S, \"eps1T\": eps1T,\n",
    "        \"R1_star_train\": R1_star,\n",
    "        \"theta_star_model\": theta_star_model,\n",
    "        \"train\": {\n",
    "            \"x0T\": x0T, \"x1T\": x1T, \"x0S\": x0S,\n",
    "            \"alpha\": alpha, \"k\": k, \"delta\": delta, \"d\": d, \"model_kind\": model_kind\n",
    "        }\n",
    "    }\n",
    "\n",
    "# ============== Stage 2: min R1_T s.t. g1,g2,g3 ==============\n",
    "\n",
    "def solve_min_R1_given_alpha_prime(alpha_prime_hat, meta, T=2000,\n",
    "                                   eta_theta=0.01,\n",
    "                                   eta_lambda1=5.0, eta_lambda2=5.0, eta_lambda3=5.0,\n",
    "                                   wd=1e-5):\n",
    "    x0T = meta[\"train\"][\"x0T\"]; x1T = meta[\"train\"][\"x1T\"]; x0S = meta[\"train\"][\"x0S\"]\n",
    "    alpha = meta[\"train\"][\"alpha\"]; k = meta[\"train\"][\"k\"]; delta = meta[\"train\"][\"delta\"]\n",
    "    eps0T = meta[\"eps0T\"]; eps0S = meta[\"eps0S\"]; eps1T = meta[\"eps1T\"]\n",
    "    d = meta[\"train\"][\"d\"]; model_kind = meta[\"train\"][\"model_kind\"]\n",
    "\n",
    "    model_factory = make_model_factory_by_kind(model_kind, d)\n",
    "    model = model_factory()\n",
    "    for p in model.parameters(): p.requires_grad_(True)\n",
    "    model.train(); _assert_trainable(model)\n",
    "\n",
    "    opt_theta = torch.optim.Adam(model.parameters(), lr=eta_theta, weight_decay=wd)\n",
    "    lam1 = lam2 = lam3 = 0.0\n",
    "    ema = ParamEMA(model)\n",
    "\n",
    "    for _ in range(1, T + 1):\n",
    "        opt_theta.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            R0T = phi_pos(model(x0T.float()), k).mean()\n",
    "            R0S = phi_pos(model(x0S.float()), k).mean()\n",
    "            R1T = phi_neg(model(x1T.float()), k).mean()\n",
    "            g1 = R0T - alpha - eps0T - delta\n",
    "            g2 = R0S - alpha_prime_hat - eps0S\n",
    "            g3 = R1T - meta[\"R1_star_train\"] - eps1T\n",
    "            L = R1T + lam1*g1 + lam2*g2 + lam3*g3\n",
    "            if not L.requires_grad:\n",
    "                raise RuntimeError(\"Stage2 objective lost grad path.\")\n",
    "            L.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "        opt_theta.step()\n",
    "        with torch.no_grad():\n",
    "            ema.update(model)\n",
    "            lam1 = min(max(lam1 + eta_lambda1 * g1.item(), 0.0), 5.0)\n",
    "            lam2 = min(max(lam2 + eta_lambda2 * g2.item(), 0.0), 5.0)\n",
    "            lam3 = min(max(lam3 + eta_lambda3 * g3.item(), 0.0), 5.0)\n",
    "\n",
    "    theta_avg_model = model_factory()\n",
    "    ema.copy_to(theta_avg_model)\n",
    "    with torch.no_grad():\n",
    "        R1_train_val = phi_neg(theta_avg_model(x1T.float()), k).mean().item()\n",
    "    return theta_avg_model, R1_train_val\n",
    "\n",
    "# ============== Stage 3: Final min R1_S with 4 constraints ==============\n",
    "\n",
    "def solve_min_R1S_final(meta_alpha, R1T_star_Halph, x1S, T=3000,\n",
    "                        eta_theta=0.01, eta_lambda=5.0, lam_cap=5.0, wd=1e-5):\n",
    "    k = meta_alpha[\"train\"][\"k\"]\n",
    "    x0T = meta_alpha[\"train\"][\"x0T\"]; x1T = meta_alpha[\"train\"][\"x1T\"]; x0S = meta_alpha[\"train\"][\"x0S\"]\n",
    "    alpha = meta_alpha[\"train\"][\"alpha\"]; delta = meta_alpha[\"train\"][\"delta\"]\n",
    "    eps0T = meta_alpha[\"eps0T\"]; eps0S = meta_alpha[\"eps0S\"]; eps1T = meta_alpha[\"eps1T\"]\n",
    "    d = meta_alpha[\"train\"][\"d\"]; model_kind = meta_alpha[\"train\"][\"model_kind\"]\n",
    "    alpha_prime_hat = float(meta_alpha[\"alpha_prime\"])\n",
    "    R1T_baseline_tight = float(meta_alpha[\"R1_star_train\"])\n",
    "\n",
    "    model_factory = make_model_factory_by_kind(model_kind, d)\n",
    "    model = model_factory()\n",
    "    for p in model.parameters(): p.requires_grad_(True)\n",
    "    model.train(); _assert_trainable(model)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=eta_theta, weight_decay=wd)\n",
    "    lam = 0.0\n",
    "    ema = ParamEMA(model)\n",
    "\n",
    "    def g_max(m):\n",
    "        R0T = phi_pos(m(x0T.float()), k).mean()\n",
    "        R0S = phi_pos(m(x0S.float()), k).mean()\n",
    "        R1T = phi_neg(m(x1T.float()), k).mean()\n",
    "        g1 = R0T - (alpha + eps0T + delta)\n",
    "        g2 = R0S - (alpha_prime_hat + eps0S)\n",
    "        g3 = R1T - R1T_baseline_tight - eps1T\n",
    "        g4 = R1T - R1T_star_Halph   - eps1T\n",
    "        return torch.stack([g1, g2, g3, g4]).max()\n",
    "\n",
    "    def R1S(m): return phi_neg(m(x1S.float()), k).mean()\n",
    "\n",
    "    for _ in range(1, T + 1):\n",
    "        opt.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            obj  = R1S(model)\n",
    "            gval = g_max(model)\n",
    "            L = obj + lam * gval\n",
    "            if not L.requires_grad:\n",
    "                raise RuntimeError(\"Final-stage objective lost grad path.\")\n",
    "            L.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "        opt.step()\n",
    "        with torch.no_grad():\n",
    "            ema.update(model)\n",
    "            lam = max(lam + eta_lambda * gval.item(), 0.0)\n",
    "            lam = min(lam, lam_cap)\n",
    "\n",
    "    theta_hat = model_factory()\n",
    "    ema.copy_to(theta_hat)\n",
    "    with torch.no_grad():\n",
    "        R1S_hat = R1S(theta_hat).item()\n",
    "    return theta_hat, R1S_hat\n",
    "\n",
    "# ============== Baseline helper (θ* over arbitrary dataset) ==============\n",
    "\n",
    "def train_theta_star_on_dataset(alpha, k, T_inner, eta_theta, eta_lambda, delta,\n",
    "                                x0, x1, c_eps0, model_factory, rho=1.0):\n",
    "    n0 = x0.shape[0]\n",
    "    eps0 = eps_from_n(n0, c_eps0)\n",
    "    alpha_tight = max(alpha - eps0, 0.0)\n",
    "    model, _, _, _, _ = train_np_sigmoid(\n",
    "        alpha=alpha_tight, k=k, T=T_inner, eta_theta=eta_theta, eta_lambda=eta_lambda,\n",
    "        delta=delta, x0=x0, x1=x1, model_factory=model_factory, lam_cap=5.0, rho=rho\n",
    "    )\n",
    "    return model, eps0\n",
    "\n",
    "# ============== Pretty print (final-only) ==============\n",
    "\n",
    "def print_final_metrics(alpha_prime, metrics_dict):\n",
    "    print(f\"alpha' = {alpha_prime:.4f}\")\n",
    "    print(\"\\n=== TARGET TEST (final) ===\")\n",
    "    order = [\n",
    "        (\"Algorithm (selected)\", \"alg\"),\n",
    "        (\"Baseline A: theta*_T (target-only)\", \"tgt\"),\n",
    "        (\"Baseline B: theta*_S (source-only)\", \"src\"),\n",
    "        (\"Baseline C: theta*_Agg (T+S aggregated)\", \"agg\"),\n",
    "    ]\n",
    "    for label, key in order:\n",
    "        m = metrics_dict[key]\n",
    "        print(f\"{label}: R0_T={m['R0T_sur']:.4f}, R1_T={m['R1T_sur']:.4f}, \"\n",
    "              f\"Type-I={m['typeI']:.4f}, Type-II={m['typeII']:.4f}\")\n",
    "\n",
    "def acc_from_metrics(m):\n",
    "    return 1.0 - 0.5 * (m[\"typeI\"] + m[\"typeII\"])\n",
    "\n",
    "# ============== Optional: ClimSim loader ==============\n",
    "\n",
    "def resolve_climsim_root():\n",
    "    env = os.getenv(\"CLIMSIM_ROOT\", None)\n",
    "    if env and os.path.isdir(env): return env\n",
    "    try:\n",
    "        here = os.path.dirname(os.path.abspath(__file__))\n",
    "    except NameError:\n",
    "        here = os.getcwd()\n",
    "    for cand in [os.path.join(here, \"data\", \"climsim_data\"),\n",
    "                 os.path.join(os.path.dirname(here), \"data\", \"climsim_data\"),\n",
    "                 os.path.join(os.getcwd(), \"data\", \"climsim_data\")]:\n",
    "        if os.path.isdir(cand): return cand\n",
    "    return None\n",
    "\n",
    "def try_import_climsim():\n",
    "    try:\n",
    "        from climsim_data_import import load_climsim_data, prepare_climsim_data\n",
    "        return load_climsim_data, prepare_climsim_data\n",
    "    except Exception as e:\n",
    "        print(f\"[warn] ClimSim import failed: {e}\")\n",
    "        return None, None\n",
    "\n",
    "# ============== NEW: Experiment runner (single instance) ==============\n",
    "\n",
    "@torch.no_grad()\n",
    "def run_alg_and_baselines_once(\n",
    "    x0T, x1T, x0S, x1S, x0T_test, x1T_test,\n",
    "    *,  # keyword-only for clarity\n",
    "    model_kind=\"mlp\",\n",
    "    alpha=0.10, k=20.0,\n",
    "    c_eps0T=0.5, c_eps0S=0.5, c_eps1T=1.0, c_eps1S=1.0,\n",
    "    T_inner_np=1500, T_alpha=4000, T_stage2=2000, T_final=3000\n",
    "):\n",
    "    d = x0T.shape[1]\n",
    "    model_factory = make_model_factory_by_kind(model_kind, d)\n",
    "    rho_baseline = 2.0 if model_kind.lower() == \"mlp\" else 1.0\n",
    "    eta_theta_star = 0.02 if model_kind.lower() == \"mlp\" else 0.1\n",
    "\n",
    "    # 1) α′ and θ* (tightened α−ε0T)\n",
    "    meta = optimize_alpha_prime(\n",
    "        alpha=alpha, x0T=x0T, x1T=x1T, x0S=x0S, model_kind=model_kind, k=k,\n",
    "        c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T,\n",
    "        T_outer=T_alpha, eta_theta=0.02, eta_alpha=0.05,\n",
    "        eta_lambda1=1.0, eta_lambda2=2.0, eta_lambda3=1.0,\n",
    "        delta=0.0, T_inner_np=T_inner_np,\n",
    "        alpha_loop_train_last_layer_only=True, alpha_loop_weight_decay=1e-4\n",
    "    )\n",
    "\n",
    "    # 2) θ_{α′,T}\n",
    "    theta_alphaT, R1T_star_Halph = solve_min_R1_given_alpha_prime(\n",
    "        alpha_prime_hat=meta[\"alpha_prime\"], meta=meta,\n",
    "        T=T_stage2, eta_theta=0.01,\n",
    "        eta_lambda1=5.0, eta_lambda2=5.0, eta_lambda3=5.0\n",
    "    )\n",
    "\n",
    "    # 3) Proxy R1S* on H(α′)\n",
    "    R1S_star_proxy = phi_neg(theta_alphaT(x1S.float()), meta[\"train\"][\"k\"]).mean().item()\n",
    "\n",
    "    # 4) Final θ̂\n",
    "    theta_hat, R1S_hat_theta = solve_min_R1S_final(\n",
    "        meta_alpha=meta, R1T_star_Halph=R1T_star_Halph,\n",
    "        x1S=x1S, T=T_final, eta_theta=0.01, eta_lambda=5.0\n",
    "    )\n",
    "\n",
    "    # Decision\n",
    "    eps1S = eps_from_n(x1S.shape[0], c_eps1S)\n",
    "    use_alphaT = (R1S_hat_theta - R1S_star_proxy > eps1S)\n",
    "    chosen_model = theta_alphaT if use_alphaT else theta_hat\n",
    "\n",
    "    # Baselines\n",
    "    theta_star_T = meta[\"theta_star_model\"]\n",
    "    theta_star_S, _ = train_theta_star_on_dataset(\n",
    "        alpha=alpha, k=k, T_inner=T_inner_np,\n",
    "        eta_theta=eta_theta_star, eta_lambda=1.0, delta=0.0,\n",
    "        x0=x0S, x1=x1S, c_eps0=c_eps0S, model_factory=model_factory, rho=rho_baseline\n",
    "    )\n",
    "    x0Agg = torch.cat([x0T, x0S], dim=0)\n",
    "    x1Agg = torch.cat([x1T, x1S], dim=0)\n",
    "    theta_star_Agg, _ = train_theta_star_on_dataset(\n",
    "        alpha=alpha, k=k, T_inner=T_inner_np,\n",
    "        eta_theta=eta_theta_star, eta_lambda=1.0, delta=0.0,\n",
    "        x0=x0Agg, x1=x1Agg, c_eps0=c_eps0T, model_factory=model_factory, rho=rho_baseline\n",
    "    )\n",
    "\n",
    "    # Evaluations on TARGET TEST\n",
    "    metrics = {}\n",
    "    metrics[\"ours\"] = evaluate_target_test(chosen_model, k, x0T_test, x1T_test)\n",
    "    metrics[\"tgt\"]  = evaluate_target_test(theta_star_T, k, x0T_test, x1T_test)\n",
    "    metrics[\"src\"]  = evaluate_target_test(theta_star_S, k, x0T_test, x1T_test)\n",
    "    metrics[\"agg\"]  = evaluate_target_test(theta_star_Agg, k, x0T_test, x1T_test)\n",
    "    return metrics\n",
    "\n",
    "# ============== NEW: Data generators for sweeps ==============\n",
    "\n",
    "def gen_gaussian_data(n0T, n1T, n0S, n1S, Ntest,\n",
    "                      d=1,\n",
    "                      t0T=-10.0, c0T=1.0, t1T=1.0, c1T=1.0,\n",
    "                      t0S=0.0,  c0S=1.0, t1S=1.0,  c1S=1.0,\n",
    "                      seed_base=0):\n",
    "    # Distinct seeds so trials don't share noise; deterministic per call\n",
    "    x0T, x1T = make_pair_sep(n0T, n1T, d, t0T, c0T, t1T, c1T, seed=seed_base + 17)\n",
    "    x0S, x1S = make_pair_sep(n0S, n1S, d, t0S, c0S, t1S, c1S, seed=seed_base + 31)\n",
    "    x0T_test, x1T_test = make_pair_sep(Ntest, Ntest, d, t0T, c0T, t1T, c1T, seed=seed_base + 59)\n",
    "    return x0T, x1T, x0S, x1S, x0T_test, x1T_test\n",
    "\n",
    "def gen_climsim_data(n0T, n1T, n0S, n1S, Ntest, d, seed_base=0):\n",
    "    load_climsim_data, prepare_climsim_data = try_import_climsim()\n",
    "    if load_climsim_data is None:\n",
    "        raise RuntimeError(\"ClimSim unavailable. Set use_gaussian=True for sweeps.\")\n",
    "    root = resolve_climsim_root()\n",
    "    cfg = {\n",
    "        \"data_frequency\": \"daily\",\n",
    "        \"data_mode\": \"cluster_4\",\n",
    "        \"targets\": [26], \"sources\": [27],\n",
    "        \"num_target_normal_training\":  n0T,\n",
    "        \"num_target_abnormal_training\": n1T,\n",
    "        \"num_source_normal\":           n0S,\n",
    "        \"num_source_abnormal\":         n1S,\n",
    "        \"num_target_normal_test\":  Ntest,\n",
    "        \"num_target_abnormal_test\": Ntest,\n",
    "        \"input_dim\": d\n",
    "    }\n",
    "    data = load_climsim_data(root_path=root)\n",
    "    x0T, x1T, x0S, x1S, x0T_test, x1T_test = prepare_climsim_data(\n",
    "        climsim_data=data, config=cfg, seed=seed_base\n",
    "    )\n",
    "    with torch.no_grad():\n",
    "        mean = torch.cat([x0T.float(), x1T.float(), x0S.float(), x1S.float()], dim=0).mean(0, keepdim=True)\n",
    "        std  = torch.cat([x0T.float(), x1T.float(), x0S.float(), x1S.float()], dim=0).std(0, unbiased=False, keepdim=True).clamp_min(1e-6)\n",
    "        x0T = (x0T.float() - mean) / std\n",
    "        x1T = (x1T.float() - mean) / std\n",
    "        x0S = (x0S.float() - mean) / std\n",
    "        x1S = (x1S.float() - mean) / std\n",
    "        x0T_test = (x0T_test.float() - mean) / std\n",
    "        x1T_test = (x1T_test.float() - mean) / std\n",
    "    return x0T, x1T, x0S, x1S, x0T_test, x1T_test\n",
    "\n",
    "# ============== NEW: Sweep functions (plots) ==============\n",
    "\n",
    "def sweep_target_samples_vs_errors(\n",
    "    nT_values,\n",
    "    *,\n",
    "    n0S=1000, n1S=1000,\n",
    "    trials=10,\n",
    "    use_gaussian=True,\n",
    "    gaussian_kwargs=None,  # dict of {d, t0T, c0T, ...}, optional\n",
    "    climsim_d=124,\n",
    "    Ntest=1700,\n",
    "    model_kind=\"mlp\",\n",
    "    alpha=0.10, k=20.0,\n",
    "    c_eps0T=0.5, c_eps0S=0.5, c_eps1T=1.0, c_eps1S=1.0,\n",
    "    T_inner_np=1500, T_alpha=4000, T_stage2=2000, T_final=3000,\n",
    "    base_seed=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Vary n_{0,T}=n_{1,T} over nT_values, average Type-I/II across 'trials', and plot.\n",
    "    Returns (fig_typeI, fig_typeII) and the aggregated results dict.\n",
    "    \"\"\"\n",
    "    gaussian_kwargs = gaussian_kwargs or {}\n",
    "    # accumulators: for each method -> list aligned with nT_values\n",
    "    sums_typeI = {k: [0.0]*len(nT_values) for k in [\"ours\",\"tgt\",\"src\",\"agg\"]}\n",
    "    sums_typeII= {k: [0.0]*len(nT_values) for k in [\"ours\",\"tgt\",\"src\",\"agg\"]}\n",
    "\n",
    "    for trial in range(trials):\n",
    "        for i, nT in enumerate(nT_values):\n",
    "            seed = base_seed + 1000*trial + 7*i\n",
    "            if use_gaussian:\n",
    "                x0T,x1T,x0S,x1S,x0T_test,x1T_test = gen_gaussian_data(\n",
    "                    n0T=nT, n1T=nT, n0S=n0S, n1S=n1S, Ntest=Ntest, seed_base=seed, **gaussian_kwargs\n",
    "                )\n",
    "            else:\n",
    "                x0T,x1T,x0S,x1S,x0T_test,x1T_test = gen_climsim_data(\n",
    "                    n0T=nT, n1T=nT, n0S=n0S, n1S=n1S, Ntest=Ntest, d=climsim_d, seed_base=seed\n",
    "                )\n",
    "            metrics = run_alg_and_baselines_once(\n",
    "                x0T,x1T,x0S,x1S,x0T_test,x1T_test,\n",
    "                model_kind=model_kind, alpha=alpha, k=k,\n",
    "                c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T, c_eps1S=c_eps1S,\n",
    "                T_inner_np=T_inner_np, T_alpha=T_alpha, T_stage2=T_stage2, T_final=T_final\n",
    "            )\n",
    "            for key in [\"ours\",\"tgt\",\"src\",\"agg\"]:\n",
    "                sums_typeI[key][i]  += metrics[key][\"typeI\"]\n",
    "                sums_typeII[key][i] += metrics[key][\"typeII\"]\n",
    "\n",
    "    # Averages\n",
    "    avg_typeI  = {k: [v/trials for v in sums_typeI[k]]  for k in sums_typeI}\n",
    "    avg_typeII = {k: [v/trials for v in sums_typeII[k]] for k in sums_typeII}\n",
    "\n",
    "    # --- Plot Type-I ---\n",
    "    fig1 = plt.figure()\n",
    "    for key, label in [(\"ours\",\"Our method\"), (\"tgt\",\"θ*_T\"), (\"src\",\"θ*_S\"), (\"agg\",\"θ*_Agg\")]:\n",
    "        plt.plot(nT_values, avg_typeI[key], label=label, marker='o')\n",
    "    plt.xlabel(\"Number of TARGET samples per class (n0T = n1T)\")\n",
    "    plt.ylabel(\"Type-I (TARGET test)\")\n",
    "    plt.title(\"Type-I vs Target Samples\")\n",
    "    plt.grid(True, alpha=0.3); plt.legend()\n",
    "\n",
    "    # --- Plot Type-II ---\n",
    "    fig2 = plt.figure()\n",
    "    for key, label in [(\"ours\",\"Our method\"), (\"tgt\",\"θ*_T\"), (\"src\",\"θ*_S\"), (\"agg\",\"θ*_Agg\")]:\n",
    "        plt.plot(nT_values, avg_typeII[key], label=label, marker='o')\n",
    "    plt.xlabel(\"Number of TARGET samples per class (n0T = n1T)\")\n",
    "    plt.ylabel(\"Type-II (TARGET test)\")\n",
    "    plt.title(\"Type-II vs Target Samples\")\n",
    "    plt.grid(True, alpha=0.3); plt.legend()\n",
    "\n",
    "    results = {\"avg_typeI\": avg_typeI, \"avg_typeII\": avg_typeII}\n",
    "    return fig1, fig2, results\n",
    "\n",
    "def sweep_source_samples_vs_errors(\n",
    "    nS_values,\n",
    "    *,\n",
    "    n0T=40, n1T=40,           # fixed target sizes\n",
    "    trials=10,\n",
    "    use_gaussian=True,\n",
    "    gaussian_kwargs=None,\n",
    "    climsim_d=124,\n",
    "    Ntest=1700,\n",
    "    model_kind=\"mlp\",\n",
    "    alpha=0.10, k=20.0,\n",
    "    c_eps0T=0.5, c_eps0S=0.5, c_eps1T=1.0, c_eps1S=1.0,\n",
    "    T_inner_np=1500, T_alpha=4000, T_stage2=2000, T_final=3000,\n",
    "    base_seed=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Vary n_{0,S}=n_{1,S} over nS_values, average Type-I/II across 'trials', and plot.\n",
    "    Returns (fig_typeI, fig_typeII) and the aggregated results dict.\n",
    "    \"\"\"\n",
    "    gaussian_kwargs = gaussian_kwargs or {}\n",
    "    sums_typeI = {k: [0.0]*len(nS_values) for k in [\"ours\",\"tgt\",\"src\",\"agg\"]}\n",
    "    sums_typeII= {k: [0.0]*len(nS_values) for k in [\"ours\",\"tgt\",\"src\",\"agg\"]}\n",
    "\n",
    "    for trial in range(trials):\n",
    "        for i, nS in enumerate(nS_values):\n",
    "            seed = base_seed + 1000*trial + 11*i\n",
    "            if use_gaussian:\n",
    "                x0T,x1T,x0S,x1S,x0T_test,x1T_test = gen_gaussian_data(\n",
    "                    n0T=n0T, n1T=n1T, n0S=nS, n1S=nS, Ntest=Ntest, seed_base=seed, **gaussian_kwargs\n",
    "                )\n",
    "            else:\n",
    "                x0T,x1T,x0S,x1S,x0T_test,x1T_test = gen_climsim_data(\n",
    "                    n0T=n0T, n1T=n1T, n0S=nS, n1S=nS, Ntest=Ntest, d=climsim_d, seed_base=seed\n",
    "                )\n",
    "            metrics = run_alg_and_baselines_once(\n",
    "                x0T,x1T,x0S,x1S,x0T_test,x1T_test,\n",
    "                model_kind=model_kind, alpha=alpha, k=k,\n",
    "                c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T, c_eps1S=c_eps1S,\n",
    "                T_inner_np=T_inner_np, T_alpha=T_alpha, T_stage2=T_stage2, T_final=T_final\n",
    "            )\n",
    "            for key in [\"ours\",\"tgt\",\"src\",\"agg\"]:\n",
    "                sums_typeI[key][i]  += metrics[key][\"typeI\"]\n",
    "                sums_typeII[key][i] += metrics[key][\"typeII\"]\n",
    "\n",
    "    avg_typeI  = {k: [v/trials for v in sums_typeI[k]]  for k in sums_typeI}\n",
    "    avg_typeII = {k: [v/trials for v in sums_typeII[k]] for k in sums_typeII}\n",
    "\n",
    "    # --- Plot Type-I ---\n",
    "    fig1 = plt.figure()\n",
    "    for key, label in [(\"ours\",\"Our method\"), (\"tgt\",\"θ*_T\"), (\"src\",\"θ*_S\"), (\"agg\",\"θ*_Agg\")]:\n",
    "        plt.plot(nS_values, avg_typeI[key], label=label, marker='o')\n",
    "    plt.xlabel(\"Number of SOURCE samples per class (n0S = n1S)\")\n",
    "    plt.ylabel(\"Type-I (TARGET test)\")\n",
    "    plt.title(\"Type-I vs Source Samples\")\n",
    "    plt.grid(True, alpha=0.3); plt.legend()\n",
    "\n",
    "    # --- Plot Type-II ---\n",
    "    fig2 = plt.figure()\n",
    "    for key, label in [(\"ours\",\"Our method\"), (\"tgt\",\"θ*_T\"), (\"src\",\"θ*_S\"), (\"agg\",\"θ*_Agg\")]:\n",
    "        plt.plot(nS_values, avg_typeII[key], label=label, marker='o')\n",
    "    plt.xlabel(\"Number of SOURCE samples per class (n0S = n1S)\")\n",
    "    plt.ylabel(\"Type-II (TARGET test)\")\n",
    "    plt.title(\"Type-II vs Source Samples\")\n",
    "    plt.grid(True, alpha=0.3); plt.legend()\n",
    "\n",
    "    results = {\"avg_typeI\": avg_typeI, \"avg_typeII\": avg_typeII}\n",
    "    return fig1, fig2, results\n",
    "\n",
    "# ============== Main (unchanged algorithm workflow + optional demo) ==============\n",
    "\n",
    "def main():\n",
    "    torch.manual_seed(0)\n",
    "\n",
    "    # ---- Choose data source (default: GAUSSIAN) ----\n",
    "    USE_GAUSSIAN = False\n",
    "\n",
    "    # ---- Choose model kind: \"linear\" or \"mlp\" ----\n",
    "    MODEL_KIND = \"mlp\"  # or \"linear\"\n",
    "\n",
    "    # ---- Common knobs ----\n",
    "    alpha = 0.10\n",
    "    k = 20.0\n",
    "    T_inner_np = 1500\n",
    "    T_alpha    = 4000\n",
    "    T_stage2   = 2000\n",
    "    T_final    = 3000\n",
    "\n",
    "    # eps scales\n",
    "    c_eps0T = 0.5\n",
    "    c_eps0S = 0.5\n",
    "    c_eps1T = 1.0\n",
    "    c_eps1S = 1.0\n",
    "\n",
    "    # sizes\n",
    "    n0T, n1T = 40, 40\n",
    "    n0S, n1S = 1000, 1000\n",
    "    Ntest = 1700\n",
    "\n",
    "    if USE_GAUSSIAN:\n",
    "        d = 1\n",
    "        t0T, c0T = -10, 1.0\n",
    "        t1T, c1T = 1.0, 1.0\n",
    "        t0S, c0S_ = 0.0, 1.0\n",
    "        t1S, c1S_ = 1.0, 1.0\n",
    "        x0T, x1T = make_pair_sep(n0T, n1T, d, t0T, c0T, t1T, c1T, seed=0)\n",
    "        x0S, x1S = make_pair_sep(n0S, n1S, d, t0S, c0S_, t1S, c1S_, seed=1)\n",
    "        x0T_test, x1T_test = make_pair_sep(Ntest, Ntest, d, t0T, c0T, t1T, c1T, seed=2)\n",
    "    else:\n",
    "        load_climsim_data, prepare_climsim_data = try_import_climsim()\n",
    "        if load_climsim_data is None:\n",
    "            raise RuntimeError(\"Set USE_GAUSSIAN=True or install/provide climsim_data_import.\")\n",
    "        d = 124\n",
    "        root = resolve_climsim_root()\n",
    "        cfg = {\n",
    "            \"data_frequency\": \"daily\",\n",
    "            \"data_mode\": \"cluster_4\",\n",
    "            \"targets\": [26], \"sources\": [27],\n",
    "            \"num_target_normal_training\":  n0T,\n",
    "            \"num_target_abnormal_training\": n1T,\n",
    "            \"num_source_normal\":           n0S,\n",
    "            \"num_source_abnormal\":         n1S,\n",
    "            \"num_target_normal_test\":  Ntest,\n",
    "            \"num_target_abnormal_test\": Ntest,\n",
    "            \"input_dim\": d\n",
    "        }\n",
    "        data = load_climsim_data(root_path=root)\n",
    "        x0T, x1T, x0S, x1S, x0T_test, x1T_test = prepare_climsim_data(\n",
    "            climsim_data=data, config=cfg, seed=0\n",
    "        )\n",
    "        with torch.no_grad():\n",
    "            mean = torch.cat([x0T.float(), x1T.float(), x0S.float(), x1S.float()], dim=0).mean(0, keepdim=True)\n",
    "            std  = torch.cat([x0T.float(), x1T.float(), x0S.float(), x1S.float()], dim=0).std(0, unbiased=False, keepdim=True).clamp_min(1e-6)\n",
    "            x0T = (x0T.float() - mean) / std\n",
    "            x1T = (x1T.float() - mean) / std\n",
    "            x0S = (x0S.float() - mean) / std\n",
    "            x1S = (x1S.float() - mean) / std\n",
    "            x0T_test = (x0T_test.float() - mean) / std\n",
    "            x1T_test = (x1T_test.float() - mean) / std\n",
    "\n",
    "    # --------- Run one standard pipeline + baselines and print concise summary ---------\n",
    "    metrics_once = run_alg_and_baselines_once(\n",
    "        x0T, x1T, x0S, x1S, x0T_test, x1T_test,\n",
    "        model_kind=MODEL_KIND, alpha=alpha, k=k,\n",
    "        c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T, c_eps1S=c_eps1S,\n",
    "        T_inner_np=T_inner_np, T_alpha=T_alpha, T_stage2=T_stage2, T_final=T_final\n",
    "    )\n",
    "    # Minimal one-run print (optional; comment out if you want completely silent main)\n",
    "    print(\"\\nOne-run (TARGET TEST) — Type-I / Type-II:\")\n",
    "    for key, label in [(\"ours\",\"Our\"), (\"tgt\",\"θ*_T\"), (\"src\",\"θ*_S\"), (\"agg\",\"θ*_Agg\")]:\n",
    "        m = metrics_once[key]\n",
    "        print(f\"{label}: Type-I={m['typeI']:.4f}, Type-II={m['typeII']:.4f}\")\n",
    "\n",
    "    # --------- OPTIONAL DEMO: run sweeps (disabled by default) ---------\n",
    "    RUN_SWEEPS = False\n",
    "    if RUN_SWEEPS:\n",
    "        # Target sweep\n",
    "        nT_values = list(range(30, 301, 30))\n",
    "        sweep_target_samples_vs_errors(\n",
    "            nT_values,\n",
    "            n0S=n0S, n1S=n1S, trials=10, use_gaussian=USE_GAUSSIAN, Ntest=Ntest,\n",
    "            model_kind=MODEL_KIND, alpha=alpha, k=k,\n",
    "            c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T, c_eps1S=c_eps1S,\n",
    "            T_inner_np=T_inner_np, T_alpha=T_alpha, T_stage2=T_stage2, T_final=T_final\n",
    "        )\n",
    "        # Source sweep\n",
    "        nS_values = list(range(50, 1001, 95))\n",
    "        sweep_source_samples_vs_errors(\n",
    "            nS_values,\n",
    "            n0T=n0T, n1T=n1T, trials=10, use_gaussian=USE_GAUSSIAN, Ntest=Ntest,\n",
    "            model_kind=MODEL_KIND, alpha=alpha, k=k,\n",
    "            c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T, c_eps1S=c_eps1S,\n",
    "            T_inner_np=T_inner_np, T_alpha=T_alpha, T_stage2=T_stage2, T_final=T_final\n",
    "        )\n",
    "        plt.show()\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44e272dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "nS_values = list(range(50, 1001, 50))\n",
    "figI, figII, res = sweep_source_samples_vs_errors(\n",
    "    nS_values,\n",
    "    n0T=40, n1T=40, trials=10,\n",
    "    use_gaussian=False,\n",
    "    model_kind=\"mlp\", alpha=0.1, k=20.0\n",
    ")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
