{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6596ef2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np_np_nasa_or_climsim_all.py\n",
    "# ------------------------------------------------------------\n",
    "# Neyman–Pearson (Target + Source) with decision rule\n",
    "# Data modes:\n",
    "#   (1) Multidimensional Gaussians (stable, no standardization)\n",
    "#   (2) ClimSim loader (optional; requires climsim_data_import)\n",
    "#   (3) NASA rain loader (CSV, input_dim=6 by default)\n",
    "#\n",
    "# Architecture switch used EVERYWHERE:\n",
    "#   \"linear\"  or  \"mlp\" (ReLU, 16,16)\n",
    "#\n",
    "# Baselines:\n",
    "#   (A) θ*_T (target-only)\n",
    "#   (B) θ*_S (source-only)\n",
    "#   (C) θ*_Agg (T+S aggregated)\n",
    "#\n",
    "# Sweeps (no algorithm changes):\n",
    "#   • sweep_target_samples_vs_errors(...) : vary n_{0,T}=n_{1,T}; per trial fix SOURCE + TARGET TEST\n",
    "#   • sweep_source_samples_vs_errors(...) : vary n_{0,S}=n_{1,S}; per trial fix TARGET TRAIN + TARGET TEST\n",
    "#     => θ*_S ~flat vs target-size sweep; θ*_T ~flat vs source-size sweep.\n",
    "#\n",
    "# Save helpers:\n",
    "#   save_sweep_results_csv/json, save_figure, plot_and_save_*_sweep\n",
    "# ------------------------------------------------------------\n",
    "\n",
    "import os\n",
    "import math\n",
    "import time\n",
    "import random\n",
    "import csv\n",
    "import json\n",
    "import torch\n",
    "import pandas as pd\n",
    "from torch import nn\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# ============== Utils ==============\n",
    "\n",
    "def eps_from_n(n, c=1.0):\n",
    "    return float(c) / math.sqrt(float(max(1, n)))\n",
    "\n",
    "def make_gaussian(n, d, mean, var, generator=None):\n",
    "    g = generator\n",
    "    x = torch.randn(n, d, generator=g) if g is not None else torch.randn(n, d)\n",
    "    return x * math.sqrt(var) + mean * torch.ones(1, d)\n",
    "\n",
    "def make_pair_sep(n0, n1, d, t0, c0, t1, c1, seed=None):\n",
    "    g = torch.Generator().manual_seed(int(seed)) if seed is not None else None\n",
    "    x0 = make_gaussian(n0, d, t0, c0, g)\n",
    "    x1 = make_gaussian(n1, d, t1, c1, g)\n",
    "    return x0, x1\n",
    "\n",
    "def phi_pos(z, k):  # for R0\n",
    "    return torch.sigmoid(k * z)\n",
    "\n",
    "def phi_neg(z, k):  # for R1\n",
    "    return torch.sigmoid(-k * z)\n",
    "\n",
    "@torch.no_grad()\n",
    "def evaluate_target_test(model, k, x0t, x1t, title=None):\n",
    "    h0 = model(x0t.float()); h1 = model(x1t.float())\n",
    "    R0T_sur = phi_pos(h0, k).mean().item()\n",
    "    R1T_sur = phi_neg(h1, k).mean().item()\n",
    "    typeI   = (h0 > 0).float().mean().item()\n",
    "    typeII  = (h1 <= 0).float().mean().item()\n",
    "    return {\"R0T_sur\": R0T_sur, \"R1T_sur\": R1T_sur, \"typeI\": typeI, \"typeII\": typeII}\n",
    "\n",
    "def _assert_trainable(model):\n",
    "    if not any(p.requires_grad for p in model.parameters()):\n",
    "        raise RuntimeError(\"Model has no trainable parameters (requires_grad=False).\")\n",
    "\n",
    "# ============== Models & Factories ==============\n",
    "\n",
    "class MLP2(nn.Module):\n",
    "    def __init__(self, d, h1=16, h2=16):\n",
    "        super().__init__()\n",
    "        self.l1  = nn.Linear(d, h1, bias=True)\n",
    "        self.l2  = nn.Linear(h1, h2, bias=True)\n",
    "        self.out = nn.Linear(h2, 1,  bias=True)\n",
    "        self.act = nn.ReLU()\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        nn.init.kaiming_normal_(self.l1.weight, nonlinearity='relu'); nn.init.zeros_(self.l1.bias)\n",
    "        nn.init.kaiming_normal_(self.l2.weight, nonlinearity='relu'); nn.init.zeros_(self.l2.bias)\n",
    "        nn.init.xavier_uniform_(self.out.weight, gain=0.1);          nn.init.zeros_(self.out.bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.float()\n",
    "        x = self.act(self.l1(x))\n",
    "        x = self.act(self.l2(x))\n",
    "        return self.out(x)\n",
    "\n",
    "def make_linear_factory(d):\n",
    "    def factory():\n",
    "        m = nn.Linear(d, 1, bias=True)\n",
    "        with torch.no_grad():\n",
    "            nn.init.xavier_uniform_(m.weight, gain=0.1)\n",
    "            nn.init.zeros_(m.bias)\n",
    "        return m\n",
    "    return factory\n",
    "\n",
    "def make_mlp_factory(d, hidden=(16,16)):\n",
    "    h1, h2 = hidden\n",
    "    def factory():\n",
    "        return MLP2(d, h1=h1, h2=h2)\n",
    "    return factory\n",
    "\n",
    "def make_model_factory_by_kind(kind, d):\n",
    "    kind = (kind or \"linear\").lower()\n",
    "    if kind == \"mlp\":\n",
    "        return make_mlp_factory(d, hidden=(16,16))\n",
    "    return make_linear_factory(d)\n",
    "\n",
    "def zero_init_last_layer(model):\n",
    "    with torch.no_grad():\n",
    "        if isinstance(model, nn.Linear):\n",
    "            model.weight.zero_(); model.bias.zero_()\n",
    "        elif hasattr(model, \"out\") and isinstance(model.out, nn.Linear):\n",
    "            model.out.weight.zero_(); model.out.bias.zero_()\n",
    "        else:\n",
    "            raise RuntimeError(\"Cannot locate final linear layer to zero-init.\")\n",
    "\n",
    "def set_train_last_layer_only(model, last_only=False):\n",
    "    if isinstance(model, nn.Linear):\n",
    "        for p in model.parameters(): p.requires_grad_(True)\n",
    "        return\n",
    "    if not hasattr(model, \"out\") or not isinstance(model.out, nn.Linear):\n",
    "        for p in model.parameters(): p.requires_grad_(True)\n",
    "        return\n",
    "    for name, p in model.named_parameters():\n",
    "        if last_only and not name.startswith(\"out.\"):\n",
    "            p.requires_grad_(False)\n",
    "        else:\n",
    "            p.requires_grad_(True)\n",
    "\n",
    "class ParamEMA:\n",
    "    def __init__(self, model):\n",
    "        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}\n",
    "        self.t = 0\n",
    "    @torch.no_grad()\n",
    "    def update(self, model):\n",
    "        self.t += 1\n",
    "        for k, v in model.state_dict().items():\n",
    "            self.shadow[k] += (v.detach() - self.shadow[k]) / float(self.t)\n",
    "    @torch.no_grad()\n",
    "    def copy_to(self, model):\n",
    "        model.load_state_dict(self.shadow, strict=True)\n",
    "\n",
    "# ============== θ* on TARGET (tightened α − ε0T) ==============\n",
    "\n",
    "def pretrain_unconstrained(model, x0, x1, steps=200, lr=1e-2, weight_decay=1e-5):\n",
    "    model.train(); _assert_trainable(model)\n",
    "    x0 = x0.float(); x1 = x1.float()\n",
    "    y0 = -torch.ones(x0.size(0), 1, dtype=x0.dtype); y1 = torch.ones(x1.size(0), 1, dtype=x1.dtype)\n",
    "    X = torch.cat([x0, x1], dim=0); Y = torch.cat([y0, y1], dim=0)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "    for _ in range(steps):\n",
    "        opt.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            h = model(X)\n",
    "            loss = torch.log1p(torch.exp(-Y * h)).mean()\n",
    "            if not loss.requires_grad:\n",
    "                raise RuntimeError(\"Pretrain loss lost grad path (check global no_grad).\")\n",
    "            loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "        opt.step()\n",
    "\n",
    "def train_np_sigmoid(alpha, k, T, eta_theta, eta_lambda, delta,\n",
    "                     x0, x1, model_factory, lam_cap=5.0,\n",
    "                     pretrain_steps=200, wd=1e-5, rho=1.0):\n",
    "    # Augmented Lagrangian for θ*: L = R1 + λ g + ρ·relu(g)^2,  g=R0−α−δ\n",
    "    model = model_factory()\n",
    "    for p in model.parameters(): p.requires_grad_(True)\n",
    "    model.train(); _assert_trainable(model)\n",
    "    x0 = x0.float(); x1 = x1.float()\n",
    "\n",
    "    pretrain_unconstrained(model, x0, x1, steps=pretrain_steps, lr=1e-2, weight_decay=wd)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=eta_theta, weight_decay=wd)\n",
    "    lam = 0.0\n",
    "    ema = ParamEMA(model)\n",
    "\n",
    "    for _ in range(T):\n",
    "        opt.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            h0 = model(x0); h1 = model(x1)\n",
    "            R0_hat = phi_pos(h0, k).mean()\n",
    "            R1_hat = phi_neg(h1, k).mean()\n",
    "            g_val  = R0_hat - alpha - delta\n",
    "            L = R1_hat + lam * g_val + rho * torch.relu(g_val)**2\n",
    "            if not L.requires_grad:\n",
    "                raise RuntimeError(\"NP loss lost grad path.\")\n",
    "            L.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "        opt.step()\n",
    "        with torch.no_grad():\n",
    "            ema.update(model)\n",
    "            lam = max(lam + eta_lambda * g_val.item(), 0.0)\n",
    "            if lam_cap is not None: lam = min(lam, float(lam_cap))\n",
    "\n",
    "    model_avg = model_factory()\n",
    "    ema.copy_to(model_avg)\n",
    "    with torch.no_grad():\n",
    "        h0 = model_avg(x0); h1 = model_avg(x1)\n",
    "        R0_tr = phi_pos(h0, k).mean().item()\n",
    "        R1_tr = phi_neg(h1, k).mean().item()\n",
    "        typeI_tr = (h0 > 0).float().mean().item()\n",
    "        typeII_tr = (h1 <= 0).float().mean().item()\n",
    "    return model_avg, R0_tr, R1_tr, typeI_tr, typeII_tr\n",
    "\n",
    "def solve_theta_star_T_tight(x0T, x1T, alpha, eps0T, k, T_inner,\n",
    "                             eta_theta, eta_lambda, delta, model_factory, rho=1.0):\n",
    "    alpha_tight = max(alpha - eps0T, 0.0)\n",
    "    model, _, R1_star, *_ = train_np_sigmoid(\n",
    "        alpha=alpha_tight, k=k, T=T_inner,\n",
    "        eta_theta=eta_theta, eta_lambda=eta_lambda, delta=delta,\n",
    "        x0=x0T, x1=x1T, model_factory=model_factory, lam_cap=5.0, rho=rho\n",
    "    )\n",
    "    return model, float(R1_star)\n",
    "\n",
    "# ============== α′ optimization (same model kind; safe α′) ==============\n",
    "\n",
    "def _alpha_from_u(u, alpha, tau=1e-6):\n",
    "    return float(alpha) + (1.0 - float(alpha) - tau) * torch.sigmoid(torch.tensor(u)).item()\n",
    "\n",
    "def optimize_alpha_prime(\n",
    "    alpha,\n",
    "    x0T, x1T, x0S,\n",
    "    model_kind,\n",
    "    k=20.0,\n",
    "    c_eps0T=0.5, c_eps0S=0.5, c_eps1T=1.0,\n",
    "    T_outer=4000,\n",
    "    eta_theta=0.02,\n",
    "    eta_alpha=0.05,\n",
    "    eta_lambda1=1.0, eta_lambda2=2.0, eta_lambda3=1.0,\n",
    "    delta=0.0, T_inner_np=3000,\n",
    "    alpha_loop_train_last_layer_only=False,\n",
    "    alpha_loop_weight_decay=1e-4\n",
    "):\n",
    "    n0T, d = x0T.shape; n1T, _ = x1T.shape; n0S, _ = x0S.shape\n",
    "    eps0T = eps_from_n(n0T, c_eps0T)\n",
    "    eps0S = eps_from_n(n0S, c_eps0S)\n",
    "    eps1T = eps_from_n(n1T, c_eps1T)\n",
    "\n",
    "    model_factory = make_model_factory_by_kind(model_kind, d)\n",
    "    rho_theta = 2.0 if model_kind.lower() == \"mlp\" else 1.0\n",
    "    eta_theta_theta = 0.02 if model_kind.lower() == \"mlp\" else 0.1\n",
    "\n",
    "    theta_star_model, R1_star = solve_theta_star_T_tight(\n",
    "        x0T, x1T, alpha, eps0T, k, T_inner_np,\n",
    "        eta_theta=eta_theta_theta, eta_lambda=eta_lambda1, delta=delta,\n",
    "        model_factory=model_factory, rho=rho_theta\n",
    "    )\n",
    "\n",
    "    model_alpha = model_factory()\n",
    "    zero_init_last_layer(model_alpha)\n",
    "    set_train_last_layer_only(model_alpha, last_only=alpha_loop_train_last_layer_only)\n",
    "    model_alpha.train(); _assert_trainable(model_alpha)\n",
    "\n",
    "    opt_theta = torch.optim.SGD(filter(lambda p: p.requires_grad, model_alpha.parameters()),\n",
    "                                lr=eta_theta, weight_decay=alpha_loop_weight_decay)\n",
    "\n",
    "    def R0_T_hat(): return phi_pos(model_alpha(x0T.float()), k).mean()\n",
    "    def R1_T_hat(): return phi_neg(model_alpha(x1T.float()), k).mean()\n",
    "    def R0_S_hat(): return phi_pos(model_alpha(x0S.float()), k).mean()\n",
    "\n",
    "    u = 0.0\n",
    "    alpha_prime = _alpha_from_u(u, alpha)\n",
    "\n",
    "    lam1 = lam2 = lam3 = 0.0\n",
    "    alpha_prime_avg = 0.0\n",
    "    alpha_prime_last = alpha_prime\n",
    "\n",
    "    with torch.no_grad():\n",
    "        g2_0 = (R0_S_hat() - alpha_prime - eps0S).item()\n",
    "        if g2_0 > 0:\n",
    "            lam2 = min(1.5, 5.0)\n",
    "\n",
    "    for t in range(1, T_outer + 1):\n",
    "        opt_theta.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            g1 = R0_T_hat() - alpha - eps0T - delta\n",
    "            g2 = R0_S_hat() - alpha_prime - eps0S\n",
    "            g3 = R1_T_hat() - R1_star - eps1T\n",
    "            L  = alpha_prime + lam1*g1 + lam2*g2 + lam3*g3\n",
    "            if not L.requires_grad:\n",
    "                raise RuntimeError(\"Alpha-prime objective lost grad path.\")\n",
    "            L.backward()\n",
    "        opt_theta.step()\n",
    "\n",
    "        with torch.no_grad():\n",
    "            lam1 = min(max(lam1 + eta_lambda1 * g1.item(), 0.0), 5.0)\n",
    "            lam2 = min(max(lam2 + eta_lambda2 * g2.item(), 0.0), 5.0)\n",
    "            lam3 = min(max(lam3 + eta_lambda3 * g3.item(), 0.0), 5.0)\n",
    "\n",
    "            u += eta_alpha * (lam2 - 1.0)\n",
    "            alpha_prime = _alpha_from_u(u, alpha)\n",
    "\n",
    "            alpha_prime_last = alpha_prime\n",
    "            alpha_prime_avg += (alpha_prime - alpha_prime_avg) / t\n",
    "\n",
    "    alpha_prime_avg = max(alpha_prime_avg, float(alpha))\n",
    "    alpha_prime_last = max(alpha_prime_last, float(alpha))\n",
    "\n",
    "    return {\n",
    "        \"alpha_prime\": alpha_prime_avg,\n",
    "        \"alpha_prime_last\": alpha_prime_last,\n",
    "        \"eps0T\": eps0T, \"eps0S\": eps0S, \"eps1T\": eps1T,\n",
    "        \"R1_star_train\": R1_star,\n",
    "        \"theta_star_model\": theta_star_model,\n",
    "        \"train\": {\n",
    "            \"x0T\": x0T, \"x1T\": x1T, \"x0S\": x0S,\n",
    "            \"alpha\": alpha, \"k\": k, \"delta\": delta, \"d\": d, \"model_kind\": model_kind\n",
    "        }\n",
    "    }\n",
    "\n",
    "# ============== Stage 2 & Stage 3 ==============\n",
    "\n",
    "def solve_min_R1_given_alpha_prime(alpha_prime_hat, meta, T=2000,\n",
    "                                   eta_theta=0.01,\n",
    "                                   eta_lambda1=5.0, eta_lambda2=5.0, eta_lambda3=5.0,\n",
    "                                   wd=1e-5):\n",
    "    x0T = meta[\"train\"][\"x0T\"]; x1T = meta[\"train\"][\"x1T\"]; x0S = meta[\"train\"][\"x0S\"]\n",
    "    alpha = meta[\"train\"][\"alpha\"]; k = meta[\"train\"][\"k\"]; delta = meta[\"train\"][\"delta\"]\n",
    "    eps0T = meta[\"eps0T\"]; eps0S = meta[\"eps0S\"]; eps1T = meta[\"eps1T\"]\n",
    "    d = meta[\"train\"][\"d\"]; model_kind = meta[\"train\"][\"model_kind\"]\n",
    "\n",
    "    model_factory = make_model_factory_by_kind(model_kind, d)\n",
    "    model = model_factory()\n",
    "    for p in model.parameters(): p.requires_grad_(True)\n",
    "    model.train(); _assert_trainable(model)\n",
    "\n",
    "    opt_theta = torch.optim.Adam(model.parameters(), lr=eta_theta, weight_decay=wd)\n",
    "    lam1 = lam2 = lam3 = 0.0\n",
    "    ema = ParamEMA(model)\n",
    "\n",
    "    for _ in range(1, T + 1):\n",
    "        opt_theta.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            R0T = phi_pos(model(x0T.float()), k).mean()\n",
    "            R0S = phi_pos(model(x0S.float()), k).mean()\n",
    "            R1T = phi_neg(model(x1T.float()), k).mean()\n",
    "            g1 = R0T - alpha - eps0T - delta\n",
    "            g2 = R0S - alpha_prime_hat - eps0S\n",
    "            g3 = R1T - meta[\"R1_star_train\"] - eps1T\n",
    "            L = R1T + lam1*g1 + lam2*g2 + lam3*g3\n",
    "            if not L.requires_grad:\n",
    "                raise RuntimeError(\"Stage2 objective lost grad path.\")\n",
    "            L.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "        opt_theta.step()\n",
    "        with torch.no_grad():\n",
    "            ema.update(model)\n",
    "            lam1 = min(max(lam1 + eta_lambda1 * g1.item(), 0.0), 5.0)\n",
    "            lam2 = min(max(lam2 + eta_lambda2 * g2.item(), 0.0), 5.0)\n",
    "            lam3 = min(max(lam3 + eta_lambda3 * g3.item(), 0.0), 5.0)\n",
    "\n",
    "    theta_avg_model = model_factory()\n",
    "    ema.copy_to(theta_avg_model)\n",
    "    with torch.no_grad():\n",
    "        R1_train_val = phi_neg(theta_avg_model(x1T.float()), k).mean().item()\n",
    "    return theta_avg_model, R1_train_val\n",
    "\n",
    "def solve_min_R1S_final(meta_alpha, R1T_star_Halph, x1S, T=3000,\n",
    "                        eta_theta=0.01, eta_lambda=5.0, lam_cap=5.0, wd=1e-5):\n",
    "    k = meta_alpha[\"train\"][\"k\"]\n",
    "    x0T = meta_alpha[\"train\"][\"x0T\"]; x1T = meta_alpha[\"train\"][\"x1T\"]; x0S = meta_alpha[\"train\"][\"x0S\"]\n",
    "    alpha = meta_alpha[\"train\"][\"alpha\"]; delta = meta_alpha[\"train\"][\"delta\"]\n",
    "    eps0T = meta_alpha[\"eps0T\"]; eps0S = meta_alpha[\"eps0S\"]; eps1T = meta_alpha[\"eps1T\"]\n",
    "    d = meta_alpha[\"train\"][\"d\"]; model_kind = meta_alpha[\"train\"][\"model_kind\"]\n",
    "    alpha_prime_hat = float(meta_alpha[\"alpha_prime\"])\n",
    "    R1T_baseline_tight = float(meta_alpha[\"R1_star_train\"])\n",
    "\n",
    "    model_factory = make_model_factory_by_kind(model_kind, d)\n",
    "    model = model_factory()\n",
    "    for p in model.parameters(): p.requires_grad_(True)\n",
    "    model.train(); _assert_trainable(model)\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=eta_theta, weight_decay=wd)\n",
    "    lam = 0.0\n",
    "    ema = ParamEMA(model)\n",
    "\n",
    "    def g_max(m):\n",
    "        R0T = phi_pos(m(x0T.float()), k).mean()\n",
    "        R0S = phi_pos(m(x0S.float()), k).mean()\n",
    "        R1T = phi_neg(m(x1T.float()), k).mean()\n",
    "        g1 = R0T - (alpha + eps0T + delta)\n",
    "        g2 = R0S - (alpha_prime_hat + eps0S)\n",
    "        g3 = R1T - R1T_baseline_tight - eps1T\n",
    "        g4 = R1T - R1T_star_Halph   - eps1T\n",
    "        return torch.stack([g1, g2, g3, g4]).max()\n",
    "\n",
    "    def R1S(m): return phi_neg(m(x1S.float()), k).mean()\n",
    "\n",
    "    for _ in range(1, T + 1):\n",
    "        opt.zero_grad(set_to_none=True)\n",
    "        with torch.enable_grad():\n",
    "            obj  = R1S(model)\n",
    "            gval = g_max(model)\n",
    "            L = obj + lam * gval\n",
    "            if not L.requires_grad:\n",
    "                raise RuntimeError(\"Final-stage objective lost grad path.\")\n",
    "            L.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)\n",
    "        opt.step()\n",
    "        with torch.no_grad():\n",
    "            ema.update(model)\n",
    "            lam = max(lam + eta_lambda * gval.item(), 0.0)\n",
    "            lam = min(lam, lam_cap)\n",
    "\n",
    "    theta_hat = model_factory()\n",
    "    ema.copy_to(theta_hat)\n",
    "    with torch.no_grad():\n",
    "        R1S_hat = R1S(theta_hat).item()\n",
    "    return theta_hat, R1S_hat\n",
    "\n",
    "# ============== Baseline helper ==============\n",
    "\n",
    "def train_theta_star_on_dataset(alpha, k, T_inner, eta_theta, eta_lambda, delta,\n",
    "                                x0, x1, c_eps0, model_factory, rho=1.0):\n",
    "    n0 = x0.shape[0]\n",
    "    eps0 = eps_from_n(n0, c_eps0)\n",
    "    alpha_tight = max(alpha - eps0, 0.0)\n",
    "    model, _, _, _, _ = train_np_sigmoid(\n",
    "        alpha=alpha_tight, k=k, T=T_inner, eta_theta=eta_theta, eta_lambda=eta_lambda,\n",
    "        delta=delta, x0=x0, x1=x1, model_factory=model_factory, lam_cap=5.0, rho=rho\n",
    "    )\n",
    "    return model, eps0\n",
    "\n",
    "# ============== One-run (used by sweeps) ==============\n",
    "\n",
    "@torch.no_grad()\n",
    "def run_alg_and_baselines_once(\n",
    "    x0T, x1T, x0S, x1S, x0T_test, x1T_test,\n",
    "    *, model_kind=\"mlp\",\n",
    "    alpha=0.10, k=20.0,\n",
    "    c_eps0T=0.5, c_eps0S=0.5, c_eps1T=1.0, c_eps1S=1.0,\n",
    "    T_inner_np=1500, T_alpha=4000, T_stage2=2000, T_final=3000\n",
    "):\n",
    "    d = x0T.shape[1]\n",
    "    model_factory = make_model_factory_by_kind(model_kind, d)\n",
    "    rho_baseline = 2.0 if model_kind.lower() == \"mlp\" else 1.0\n",
    "    eta_theta_star = 0.02 if model_kind.lower() == \"mlp\" else 0.1\n",
    "\n",
    "    meta = optimize_alpha_prime(\n",
    "        alpha=alpha, x0T=x0T, x1T=x1T, x0S=x0S, model_kind=model_kind, k=k,\n",
    "        c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T,\n",
    "        T_outer=T_alpha, eta_theta=0.02, eta_alpha=0.05,\n",
    "        eta_lambda1=1.0, eta_lambda2=2.0, eta_lambda3=1.0,\n",
    "        delta=0.0, T_inner_np=T_inner_np,\n",
    "        alpha_loop_train_last_layer_only=False, alpha_loop_weight_decay=1e-4\n",
    "    )\n",
    "\n",
    "    theta_alphaT, R1T_star_Halph = solve_min_R1_given_alpha_prime(\n",
    "        alpha_prime_hat=meta[\"alpha_prime\"], meta=meta,\n",
    "        T=T_stage2, eta_theta=0.01,\n",
    "        eta_lambda1=5.0, eta_lambda2=5.0, eta_lambda3=5.0\n",
    "    )\n",
    "\n",
    "    R1S_star_proxy = phi_neg(theta_alphaT(x1S.float()), meta[\"train\"][\"k\"]).mean().item()\n",
    "\n",
    "    theta_hat, R1S_hat_theta = solve_min_R1S_final(\n",
    "        meta_alpha=meta, R1T_star_Halph=R1T_star_Halph,\n",
    "        x1S=x1S, T=T_final, eta_theta=0.01, eta_lambda=5.0\n",
    "    )\n",
    "\n",
    "    eps1S = eps_from_n(x1S.shape[0], c_eps1S)\n",
    "    use_alphaT = (R1S_hat_theta - R1S_star_proxy > eps1S)\n",
    "    chosen_model = theta_alphaT if use_alphaT else theta_hat\n",
    "\n",
    "    theta_star_T = meta[\"theta_star_model\"]  # target-only\n",
    "    theta_star_S, _ = train_theta_star_on_dataset(\n",
    "        alpha=alpha, k=k, T_inner=T_inner_np,\n",
    "        eta_theta=eta_theta_star, eta_lambda=1.0, delta=0.0,\n",
    "        x0=x0S, x1=x1S, c_eps0=c_eps0S, model_factory=model_factory, rho=rho_baseline\n",
    "    )\n",
    "    x0Agg = torch.cat([x0T, x0S], dim=0)\n",
    "    x1Agg = torch.cat([x1T, x1S], dim=0)\n",
    "    theta_star_Agg, _ = train_theta_star_on_dataset(\n",
    "        alpha=alpha, k=k, T_inner=T_inner_np,\n",
    "        eta_theta=eta_theta_star, eta_lambda=1.0, delta=0.0,\n",
    "        x0=x0Agg, x1=x1Agg, c_eps0=c_eps0T, model_factory=model_factory, rho=rho_baseline\n",
    "    )\n",
    "\n",
    "    metrics = {}\n",
    "    metrics[\"ours\"] = evaluate_target_test(chosen_model, k, x0T_test, x1T_test)\n",
    "    metrics[\"tgt\"]  = evaluate_target_test(theta_star_T, k, x0T_test, x1T_test)\n",
    "    metrics[\"src\"]  = evaluate_target_test(theta_star_S, k, x0T_test, x1T_test)\n",
    "    metrics[\"agg\"]  = evaluate_target_test(theta_star_Agg, k, x0T_test, x1T_test)\n",
    "    return metrics\n",
    "\n",
    "# ============== Data generators ==============\n",
    "\n",
    "def gen_gaussian_data(n0T, n1T, n0S, n1S, Ntest,\n",
    "                      d=1,\n",
    "                      t0T=-10.0, c0T=1.0, t1T=1.0, c1T=1.0,\n",
    "                      t0S=0.0,  c0S=1.0,  t1S=1.0, c1S=1.0,\n",
    "                      seed_base=0):\n",
    "    x0T, x1T = make_pair_sep(n0T, n1T, d, t0T, c0T, t1T, c1T, seed=seed_base + 17)\n",
    "    x0S, x1S = make_pair_sep(n0S, n1S, d, t0S, c0S, t1S, c1S, seed=seed_base + 31)\n",
    "    x0T_test, x1T_test = make_pair_sep(Ntest, Ntest, d, t0T, c0T, t1T, c1T, seed=seed_base + 59)\n",
    "    return x0T, x1T, x0S, x1S, x0T_test, x1T_test\n",
    "\n",
    "# ---- ClimSim helpers (kept intact) ----\n",
    "\n",
    "def resolve_climsim_root():\n",
    "    env = os.getenv(\"CLIMSIM_ROOT\", None)\n",
    "    if env and os.path.isdir(env): return env\n",
    "    try:\n",
    "        here = os.path.dirname(os.path.abspath(__file__))\n",
    "    except NameError:\n",
    "        here = os.getcwd()\n",
    "    for cand in [os.path.join(here, \"data\", \"climsim_data\"),\n",
    "                 os.path.join(os.path.dirname(here), \"data\", \"climsim_data\"),\n",
    "                 os.path.join(os.getcwd(), \"data\", \"climsim_data\")]:\n",
    "        if os.path.isdir(cand): return cand\n",
    "    return None\n",
    "\n",
    "def try_import_climsim():\n",
    "    try:\n",
    "        from climsim_data_import import load_climsim_data, prepare_climsim_data\n",
    "        return load_climsim_data, prepare_climsim_data\n",
    "    except Exception as e:\n",
    "        print(f\"[warn] ClimSim import failed: {e}\")\n",
    "        return None, None\n",
    "\n",
    "def gen_climsim_data(n0T, n1T, n0S, n1S, Ntest, d, seed_base=0):\n",
    "    load_climsim_data, prepare_climsim_data = try_import_climsim()\n",
    "    if load_climsim_data is None:\n",
    "        raise RuntimeError(\"ClimSim unavailable. Set use_gaussian=True or use_nasa=True for sweeps.\")\n",
    "    root = resolve_climsim_root()\n",
    "    cfg = {\n",
    "        \"data_frequency\": \"daily\",\n",
    "        \"data_mode\": \"cluster_4\",\n",
    "        \"targets\": [26], \"sources\": [27],\n",
    "        \"num_target_normal_training\":  n0T,\n",
    "        \"num_target_abnormal_training\": n1T,\n",
    "        \"num_source_normal\":           n0S,\n",
    "        \"num_source_abnormal\":         n1S,\n",
    "        \"num_target_normal_test\":  Ntest,\n",
    "        \"num_target_abnormal_test\": Ntest,\n",
    "        \"input_dim\": d\n",
    "    }\n",
    "    data = load_climsim_data(root_path=root)\n",
    "    x0T, x1T, x0S, x1S, x0T_test, x1T_test = prepare_climsim_data(\n",
    "        climsim_data=data, config=cfg, seed=seed_base\n",
    "    )\n",
    "    with torch.no_grad():\n",
    "        mean = torch.cat([x0T.float(), x1T.float(), x0S.float(), x1S.float()], dim=0).mean(0, keepdim=True)\n",
    "        std  = torch.cat([x0T.float(), x1T.float(), x0S.float(), x1S.float()], dim=0).std(0, unbiased=False, keepdim=True).clamp_min(1e-6)\n",
    "        x0T = (x0T.float() - mean) / std\n",
    "        x1T = (x1T.float() - mean) / std\n",
    "        x0S = (x0S.float() - mean) / std\n",
    "        x1S = (x1S.float() - mean) / std\n",
    "        x0T_test = (x0T_test.float() - mean) / std\n",
    "        x1T_test = (x1T_test.float() - mean) / std\n",
    "    return x0T, x1T, x0S, x1S, x0T_test, x1T_test\n",
    "\n",
    "# ---- NASA helpers (new) ----\n",
    "\n",
    "def resolve_nasa_root():\n",
    "    env = os.getenv(\"NASA_ROOT\", None)\n",
    "    if env and os.path.isdir(env):\n",
    "        return env\n",
    "    try:\n",
    "        here = os.path.dirname(os.path.abspath(__file__))\n",
    "    except NameError:\n",
    "        here = os.getcwd()\n",
    "    for cand in [\n",
    "        os.path.join(here, \"data\", \"nasa_rain\"),\n",
    "        os.path.join(os.path.dirname(here), \"data\", \"nasa_rain\"),\n",
    "        os.path.join(os.getcwd(), \"data\", \"nasa_rain\"),\n",
    "    ]:\n",
    "        if os.path.isdir(cand):\n",
    "            return cand\n",
    "    return None\n",
    "\n",
    "def load_nasa_data(root_path=None):\n",
    "    root = root_path or resolve_nasa_root()\n",
    "    if root is None:\n",
    "        raise RuntimeError(\"NASA data not found. Set NASA_ROOT or place files in data/nasa_rain.\")\n",
    "    tgt_path = os.path.join(root, \"nasa_15years_target_A.csv\")\n",
    "    src_path = os.path.join(root, \"nasa_15years_source_B.csv\")\n",
    "    if not os.path.isfile(tgt_path) or not os.path.isfile(src_path):\n",
    "        raise RuntimeError(f\"NASA CSVs not found under {root} (expected nasa_15years_target_A.csv and nasa_15years_source_B.csv)\")\n",
    "    with open(tgt_path, \"r\") as f:\n",
    "        target_df = pd.read_csv(f)\n",
    "    with open(src_path, \"r\") as f:\n",
    "        source_df = pd.read_csv(f)\n",
    "    if \"rain_label\" not in target_df.columns or \"rain_label\" not in source_df.columns:\n",
    "        raise RuntimeError(\"NASA CSVs must contain a 'rain_label' column (0/1).\")\n",
    "    return target_df, source_df\n",
    "\n",
    "def prepare_nasa_data(target_df, source_df, config, seed=None):\n",
    "    # Shuffle deterministically\n",
    "    if seed is None:\n",
    "        seed = 0\n",
    "    target_df = target_df.sample(frac=1.0, random_state=seed).reset_index(drop=True)\n",
    "    source_df = source_df.sample(frac=1.0, random_state=seed).reset_index(drop=True)\n",
    "\n",
    "    # Splits by label\n",
    "    tgt_norm = target_df[target_df[\"rain_label\"] == 0]\n",
    "    tgt_abn  = target_df[target_df[\"rain_label\"] == 1]\n",
    "    src_norm = source_df[source_df[\"rain_label\"] == 0]\n",
    "    src_abn  = source_df[source_df[\"rain_label\"] == 1]\n",
    "\n",
    "    n0T_tr = int(config[\"num_target_normal_training\"])\n",
    "    n1T_tr = int(config[\"num_target_abnormal_training\"])\n",
    "    n0S    = int(config[\"num_source_normal\"])\n",
    "    n1S    = int(config[\"num_source_abnormal\"])\n",
    "    n0T_te = int(config[\"num_target_normal_test\"])\n",
    "    n1T_te = int(config[\"num_target_abnormal_test\"])\n",
    "\n",
    "    if n0T_tr + n0T_te > len(tgt_norm):\n",
    "        raise ValueError(\"num_target_normal_training + num_target_normal_test exceeds available target normals\")\n",
    "    if n1T_tr + n1T_te > len(tgt_abn):\n",
    "        raise ValueError(\"num_target_abnormal_training + num_target_abnormal_test exceeds available target abnormals\")\n",
    "    if n0S > len(src_norm):\n",
    "        raise ValueError(\"num_source_normal exceeds available source normals\")\n",
    "    if n1S > len(src_abn):\n",
    "        raise ValueError(\"num_source_abnormal exceeds available source abnormals\")\n",
    "\n",
    "    # Take train/test\n",
    "    tgt_norm_train = tgt_norm.iloc[:n0T_tr]\n",
    "    tgt_norm_test  = tgt_norm.iloc[len(tgt_norm)-n0T_te:]\n",
    "    tgt_abn_train  = tgt_abn.iloc[:n1T_tr]\n",
    "    tgt_abn_test   = tgt_abn.iloc[len(tgt_abn)-n1T_te:]\n",
    "    src_norm_train = src_norm.iloc[:n0S]\n",
    "    src_abn_train  = src_abn.iloc[:n1S]\n",
    "\n",
    "    # Feature columns: all except label; keep first input_dim if provided\n",
    "    all_feat_cols = [c for c in target_df.columns if c != \"rain_label\"]\n",
    "    d_req = int(config.get(\"input_dim\", len(all_feat_cols)))\n",
    "    feat_cols = all_feat_cols[:d_req]\n",
    "\n",
    "    def toX(df):\n",
    "        return torch.tensor(df[feat_cols].values, dtype=torch.float32)\n",
    "\n",
    "    x0T = toX(tgt_norm_train)\n",
    "    x1T = toX(tgt_abn_train)\n",
    "    x0S = toX(src_norm_train)\n",
    "    x1S = toX(src_abn_train)\n",
    "    x0T_test = toX(tgt_norm_test)\n",
    "    x1T_test = toX(tgt_abn_test)\n",
    "\n",
    "    # Standardize using TRAIN sets (same scheme as ClimSim)\n",
    "    with torch.no_grad():\n",
    "        train_all = torch.cat([x0T, x1T, x0S, x1S], dim=0)\n",
    "        mean = train_all.mean(0, keepdim=True)\n",
    "        std  = train_all.std(0, unbiased=False, keepdim=True).clamp_min(1e-6)\n",
    "        x0T      = (x0T      - mean) / std\n",
    "        x1T      = (x1T      - mean) / std\n",
    "        x0S      = (x0S      - mean) / std\n",
    "        x1S      = (x1S      - mean) / std\n",
    "        x0T_test = (x0T_test - mean) / std\n",
    "        x1T_test = (x1T_test - mean) / std\n",
    "\n",
    "    return x0T, x1T, x0S, x1S, x0T_test, x1T_test\n",
    "\n",
    "def gen_nasa_data(n0T, n1T, n0S, n1S, Ntest, d=6, seed_base=0):\n",
    "    \"\"\"\n",
    "    Load the NASA frames and build train/test tensors for the requested sizes.\n",
    "\n",
    "    n0T/n1T and n0S/n1S are the target and source training counts (normal/abnormal).\n",
    "    Ntest is used for BOTH target test classes, d is forwarded as input_dim and\n",
    "    seed_base is forwarded as the seed of prepare_nasa_data.\n",
    "\n",
    "    Returns whatever prepare_nasa_data returns; callers in this file unpack it as\n",
    "    (x0T, x1T, x0S, x1S, x0T_test, x1T_test).\n",
    "    \"\"\"\n",
    "    frames = load_nasa_data(root_path=resolve_nasa_root())\n",
    "    target_df, source_df = frames\n",
    "    # All counts are coerced to int so numpy / float sizes are accepted.\n",
    "    sizes = dict(\n",
    "        num_target_normal_training=int(n0T),\n",
    "        num_target_abnormal_training=int(n1T),\n",
    "        num_source_normal=int(n0S),\n",
    "        num_source_abnormal=int(n1S),\n",
    "        num_target_normal_test=int(Ntest),\n",
    "        num_target_abnormal_test=int(Ntest),\n",
    "        input_dim=int(d),\n",
    "    )\n",
    "    return prepare_nasa_data(target_df, source_df, sizes, seed=seed_base)\n",
    "\n",
    "# ============== SWEEPS (fixed sets so baselines behave as expected) ==============\n",
    "\n",
    "def sweep_target_samples_vs_errors(\n",
    "    nT_values,\n",
    "    *,\n",
    "    n0S=1000, n1S=1000,\n",
    "    trials=10,\n",
    "    use_gaussian=True,\n",
    "    use_nasa=False,          # NEW\n",
    "    gaussian_kwargs=None,\n",
    "    climsim_d=124,\n",
    "    nasa_d=6,                # NEW\n",
    "    Ntest=1700,\n",
    "    model_kind=\"mlp\",\n",
    "    alpha=0.10, k=20.0,\n",
    "    c_eps0T=0.5, c_eps0S=0.5, c_eps1T=1.0, c_eps1S=1.0,\n",
    "    T_inner_np=1500, T_alpha=4000, T_stage2=2000, T_final=3000,\n",
    "    base_seed=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Vary n_{0,T}=n_{1,T} and average Type-I / Type-II errors over `trials` runs.\n",
    "\n",
    "    Within each trial:\n",
    "      - Keep SOURCE training set fixed across all nT\n",
    "      - Keep TARGET TEST set fixed across all nT\n",
    "      => θ*_S curve is ~flat vs nT.\n",
    "\n",
    "    Data mode: Gaussian if use_gaussian, else NASA if use_nasa, else ClimSim.\n",
    "    gaussian_kwargs may override d, t0S/c0S, t1S/c1S, t0T/c0T, t1T/c1T.\n",
    "    model_kind, alpha, k, c_eps* and T_* are forwarded unchanged to\n",
    "    run_alg_and_baselines_once.\n",
    "\n",
    "    Returns (fig_typeI, fig_typeII, results), where results has the keys\n",
    "    avg_typeI and avg_typeII, each mapping ours/tgt/src/agg to a list with one\n",
    "    trial-averaged value per entry of nT_values.\n",
    "    \"\"\"\n",
    "    gaussian_kwargs = gaussian_kwargs or {}\n",
    "    methods = [\"ours\",\"tgt\",\"src\",\"agg\"]\n",
    "    sums_typeI  = {m: [0.0]*len(nT_values) for m in methods}\n",
    "    sums_typeII = {m: [0.0]*len(nT_values) for m in methods}\n",
    "\n",
    "    # Gaussian-mode seed offsets, relative to the trial seed.  The per-point\n",
    "    # TARGET TRAIN offset used to be 17 + 7*i, which equals the SOURCE offset\n",
    "    # (31) at i=2 and the TARGET TEST offset (59) at i=6, so those sweep points\n",
    "    # were generated with the same seed as a set meant to be independent of\n",
    "    # them.  Starting the per-point offsets at 101 keeps the three offsets\n",
    "    # distinct at every sweep index.\n",
    "    src_seed_offset, test_seed_offset, train_seed_offset = 31, 59, 101\n",
    "\n",
    "    for trial in range(trials):\n",
    "        seed_trial = base_seed + 1000*trial\n",
    "\n",
    "        # ---- per-trial fixed sets: SOURCE train + TARGET test ----\n",
    "        if use_gaussian:\n",
    "            d = gaussian_kwargs.get(\"d\", 1)\n",
    "            t0S = gaussian_kwargs.get(\"t0S\", 0.0); c0S = gaussian_kwargs.get(\"c0S\", 1.0)\n",
    "            t1S = gaussian_kwargs.get(\"t1S\", 1.0); c1S = gaussian_kwargs.get(\"c1S\", 1.0)\n",
    "            t0T = gaussian_kwargs.get(\"t0T\", -10.0); c0T = gaussian_kwargs.get(\"c0T\", 1.0)\n",
    "            t1T = gaussian_kwargs.get(\"t1T\", 1.0);  c1T = gaussian_kwargs.get(\"c1T\", 1.0)\n",
    "\n",
    "            x0S_fixed, x1S_fixed = make_pair_sep(\n",
    "                n0S, n1S, d, t0S, c0S, t1S, c1S, seed=seed_trial+src_seed_offset\n",
    "            )\n",
    "            x0T_test_fixed, x1T_test_fixed = make_pair_sep(\n",
    "                Ntest, Ntest, d, t0T, c0T, t1T, c1T, seed=seed_trial+test_seed_offset\n",
    "            )\n",
    "        elif use_nasa:\n",
    "            # Only the source and test tensors of this call are kept.\n",
    "            # NOTE(review): prepare_nasa_data appears to standardize with statistics of\n",
    "            # the concatenated TRAIN tensors (target train included), so these fixed\n",
    "            # tensors and the per-point target tensors below are scaled with different\n",
    "            # mean/std whenever nT != nT_values[0] -- confirm this is acceptable.\n",
    "            _, _, x0S_fixed, x1S_fixed, x0T_test_fixed, x1T_test_fixed = gen_nasa_data(\n",
    "                n0T=nT_values[0], n1T=nT_values[0], n0S=n0S, n1S=n1S, Ntest=Ntest, d=nasa_d, seed_base=seed_trial\n",
    "            )\n",
    "        else:\n",
    "            _, _, x0S_fixed, x1S_fixed, x0T_test_fixed, x1T_test_fixed = gen_climsim_data(\n",
    "                n0T=nT_values[0], n1T=nT_values[0], n0S=n0S, n1S=n1S, Ntest=Ntest, d=climsim_d, seed_base=seed_trial\n",
    "            )\n",
    "\n",
    "        for i, nT in enumerate(nT_values):\n",
    "            seed_point = seed_trial + 7*i\n",
    "            # ---- per-point TARGET train set ----\n",
    "            if use_gaussian:\n",
    "                x0T, x1T = make_pair_sep(\n",
    "                    nT, nT, d, t0T, c0T, t1T, c1T, seed=seed_point+train_seed_offset\n",
    "                )\n",
    "            elif use_nasa:\n",
    "                x0T, x1T, _, _, _, _ = gen_nasa_data(\n",
    "                    n0T=nT, n1T=nT, n0S=n0S, n1S=n1S, Ntest=Ntest, d=nasa_d, seed_base=seed_point\n",
    "                )\n",
    "            else:\n",
    "                x0T, x1T, _, _, _, _ = gen_climsim_data(\n",
    "                    n0T=nT, n1T=nT, n0S=n0S, n1S=n1S, Ntest=Ntest, d=climsim_d, seed_base=seed_point\n",
    "                )\n",
    "            x0S, x1S = x0S_fixed, x1S_fixed\n",
    "            x0T_test, x1T_test = x0T_test_fixed, x1T_test_fixed\n",
    "\n",
    "            metrics = run_alg_and_baselines_once(\n",
    "                x0T, x1T, x0S, x1S, x0T_test, x1T_test,\n",
    "                model_kind=model_kind, alpha=alpha, k=k,\n",
    "                c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T, c_eps1S=c_eps1S,\n",
    "                T_inner_np=T_inner_np, T_alpha=T_alpha, T_stage2=T_stage2, T_final=T_final\n",
    "            )\n",
    "            for key in methods:\n",
    "                sums_typeI[key][i]  += metrics[key][\"typeI\"]\n",
    "                sums_typeII[key][i] += metrics[key][\"typeII\"]\n",
    "\n",
    "    avg_typeI  = {m: [v/trials for v in sums_typeI[m]]  for m in methods}\n",
    "    avg_typeII = {m: [v/trials for v in sums_typeII[m]] for m in methods}\n",
    "\n",
    "    curve_labels = [(\"ours\",\"Our method\"), (\"tgt\",\"θ*_T\"), (\"src\",\"θ*_S\"), (\"agg\",\"θ*_Agg\")]\n",
    "\n",
    "    def _plot_curves(averages, ylabel, title):\n",
    "        # One figure with one curve per method, x-axis = nT_values.\n",
    "        fig = plt.figure()\n",
    "        for key, label in curve_labels:\n",
    "            plt.plot(nT_values, averages[key], label=label, marker='o')\n",
    "        plt.xlabel(\"Number of TARGET samples per class (n0T = n1T)\")\n",
    "        plt.ylabel(ylabel)\n",
    "        plt.title(title)\n",
    "        plt.grid(True, alpha=0.3); plt.legend()\n",
    "        return fig\n",
    "\n",
    "    # --- Plot Type-I, then Type-II ---\n",
    "    fig1 = _plot_curves(avg_typeI, \"Type-I (TARGET test)\", \"Type-I vs Target Samples\")\n",
    "    fig2 = _plot_curves(avg_typeII, \"Type-II (TARGET test)\", \"Type-II vs Target Samples\")\n",
    "\n",
    "    results = {\"avg_typeI\": avg_typeI, \"avg_typeII\": avg_typeII}\n",
    "    return fig1, fig2, results\n",
    "\n",
    "def sweep_source_samples_vs_errors(\n",
    "    nS_values,\n",
    "    *,\n",
    "    n0T=40, n1T=40,           # fixed target sizes\n",
    "    trials=10,\n",
    "    use_gaussian=True,\n",
    "    use_nasa=False,          # NEW\n",
    "    gaussian_kwargs=None,\n",
    "    climsim_d=124,\n",
    "    nasa_d=6,                # NEW\n",
    "    Ntest=1700,\n",
    "    model_kind=\"mlp\",\n",
    "    alpha=0.10, k=20.0,\n",
    "    c_eps0T=0.5, c_eps0S=0.5, c_eps1T=1.0, c_eps1S=1.0,\n",
    "    T_inner_np=1500, T_alpha=4000, T_stage2=2000, T_final=3000,\n",
    "    base_seed=0\n",
    "):\n",
    "    \"\"\"\n",
    "    Sweep n_{0,S}=n_{1,S} and average Type-I / Type-II errors over `trials` runs.\n",
    "\n",
    "    Per trial the TARGET TRAIN pair and the TARGET TEST pair are produced once and\n",
    "    reused at every nS; only the SOURCE pair is regenerated per sweep point.\n",
    "    So θ*_T (target-only) is unaffected by nS and appears ~flat.\n",
    "\n",
    "    Data mode: Gaussian when use_gaussian, otherwise NASA when use_nasa,\n",
    "    otherwise ClimSim.\n",
    "\n",
    "    Returns (fig_typeI, fig_typeII, results); results holds avg_typeI and\n",
    "    avg_typeII, each mapping ours/tgt/src/agg to one averaged value per entry\n",
    "    of nS_values.\n",
    "    \"\"\"\n",
    "    gk = gaussian_kwargs or {}\n",
    "    method_keys = [\"ours\",\"tgt\",\"src\",\"agg\"]\n",
    "    n_points = len(nS_values)\n",
    "    total_typeI  = {m: [0.0]*n_points for m in method_keys}\n",
    "    total_typeII = {m: [0.0]*n_points for m in method_keys}\n",
    "\n",
    "    for trial_idx in range(trials):\n",
    "        trial_seed = base_seed + 2000*trial_idx\n",
    "\n",
    "        # ---- per-trial fixed sets: TARGET train + TARGET test ----\n",
    "        if use_gaussian:\n",
    "            d = gk.get(\"d\", 1)\n",
    "            t0S, c0S = gk.get(\"t0S\", 0.0), gk.get(\"c0S\", 1.0)\n",
    "            t1S, c1S = gk.get(\"t1S\", 1.0), gk.get(\"c1S\", 1.0)\n",
    "            t0T, c0T = gk.get(\"t0T\", -10.0), gk.get(\"c0T\", 1.0)\n",
    "            t1T, c1T = gk.get(\"t1T\", 1.0), gk.get(\"c1T\", 1.0)\n",
    "\n",
    "            x0T, x1T = make_pair_sep(n0T, n1T, d, t0T, c0T, t1T, c1T, seed=trial_seed+17)\n",
    "            x0T_test, x1T_test = make_pair_sep(Ntest, Ntest, d, t0T, c0T, t1T, c1T, seed=trial_seed+59)\n",
    "        else:\n",
    "            # NASA and ClimSim share one call shape; only the loader and the\n",
    "            # feature dimension differ.\n",
    "            if use_nasa:\n",
    "                loader, dim = gen_nasa_data, nasa_d\n",
    "            else:\n",
    "                loader, dim = gen_climsim_data, climsim_d\n",
    "            first_nS = nS_values[0]\n",
    "            # The source tensors of this first call are discarded.\n",
    "            x0T, x1T, _, _, x0T_test, x1T_test = loader(\n",
    "                n0T=n0T, n1T=n1T, n0S=first_nS, n1S=first_nS, Ntest=Ntest, d=dim, seed_base=trial_seed\n",
    "            )\n",
    "\n",
    "        for idx, nS in enumerate(nS_values):\n",
    "            point_seed = trial_seed + 13*idx\n",
    "            # ---- per-point SOURCE train set ----\n",
    "            if use_gaussian:\n",
    "                x0S, x1S = make_pair_sep(nS, nS, d, t0S, c0S, t1S, c1S, seed=point_seed+31)\n",
    "            else:\n",
    "                _, _, x0S, x1S, _, _ = loader(\n",
    "                    n0T=n0T, n1T=n1T, n0S=nS, n1S=nS, Ntest=Ntest, d=dim, seed_base=point_seed\n",
    "                )\n",
    "\n",
    "            metrics = run_alg_and_baselines_once(\n",
    "                x0T, x1T, x0S, x1S, x0T_test, x1T_test,\n",
    "                model_kind=model_kind, alpha=alpha, k=k,\n",
    "                c_eps0T=c_eps0T, c_eps0S=c_eps0S, c_eps1T=c_eps1T, c_eps1S=c_eps1S,\n",
    "                T_inner_np=T_inner_np, T_alpha=T_alpha, T_stage2=T_stage2, T_final=T_final\n",
    "            )\n",
    "            for m in method_keys:\n",
    "                total_typeI[m][idx]  += metrics[m][\"typeI\"]\n",
    "                total_typeII[m][idx] += metrics[m][\"typeII\"]\n",
    "\n",
    "    avg_typeI  = {m: [total/trials for total in total_typeI[m]]  for m in method_keys}\n",
    "    avg_typeII = {m: [total/trials for total in total_typeII[m]] for m in method_keys}\n",
    "\n",
    "    legend_entries = [(\"ours\",\"Our method\"), (\"tgt\",\"θ*_T\"), (\"src\",\"θ*_S\"), (\"agg\",\"θ*_Agg\")]\n",
    "\n",
    "    def _draw(averages, ylabel, title):\n",
    "        # One figure with one curve per method, x-axis = nS_values.\n",
    "        fig = plt.figure()\n",
    "        for key, label in legend_entries:\n",
    "            plt.plot(nS_values, averages[key], label=label, marker='o')\n",
    "        plt.xlabel(\"Number of SOURCE samples per class (n0S = n1S)\")\n",
    "        plt.ylabel(ylabel)\n",
    "        plt.title(title)\n",
    "        plt.grid(True, alpha=0.3); plt.legend()\n",
    "        return fig\n",
    "\n",
    "    fig1 = _draw(avg_typeI, \"Type-I (TARGET test)\", \"Type-I vs Source Samples\")\n",
    "    fig2 = _draw(avg_typeII, \"Type-II (TARGET test)\", \"Type-II vs Source Samples\")\n",
    "\n",
    "    return fig1, fig2, {\"avg_typeI\": avg_typeI, \"avg_typeII\": avg_typeII}\n",
    "\n",
    "# ============== Save helpers (CSV/JSON + figures) ==============\n",
    "\n",
    "def save_sweep_results_csv(x_values, results, x_name, out_csv_path):\n",
    "    \"\"\"\n",
    "    Write the averaged sweep errors as a wide CSV, one row per swept x value.\n",
    "\n",
    "    Header: x_name, then <method>_TypeI for ours/tgt/src/agg, then <method>_TypeII\n",
    "    in the same method order.  results must carry avg_typeI and avg_typeII, each a\n",
    "    mapping method -> list indexed like x_values.\n",
    "    \"\"\"\n",
    "    method_keys = (\"ours\", \"tgt\", \"src\", \"agg\")\n",
    "    columns = [x_name]\n",
    "    for suffix in (\"TypeI\", \"TypeII\"):\n",
    "        columns.extend(f\"{m}_{suffix}\" for m in method_keys)\n",
    "\n",
    "    # newline=\"\" lets the csv module control line endings itself.\n",
    "    with open(out_csv_path, \"w\", newline=\"\") as handle:\n",
    "        writer = csv.writer(handle)\n",
    "        writer.writerow(columns)\n",
    "        for idx, x in enumerate(x_values):\n",
    "            type1_cells = [results[\"avg_typeI\"][m][idx] for m in method_keys]\n",
    "            type2_cells = [results[\"avg_typeII\"][m][idx] for m in method_keys]\n",
    "            writer.writerow([x, *type1_cells, *type2_cells])\n",
    "\n",
    "def save_sweep_results_json(x_values, results, x_name, out_json_path):\n",
    "    \"\"\"\n",
    "    Write the averaged sweep errors as indented JSON.\n",
    "\n",
    "    Layout: { x_name: [...], avg_typeI: {ours: [...]...}, avg_typeII: {...} }\n",
    "    x_values is materialized with list() so any iterable is accepted.\n",
    "    \"\"\"\n",
    "    document = {}\n",
    "    document[x_name] = list(x_values)\n",
    "    for section in (\"avg_typeI\", \"avg_typeII\"):\n",
    "        document[section] = results[section]\n",
    "    with open(out_json_path, \"w\") as handle:\n",
    "        json.dump(document, handle, indent=2)\n",
    "\n",
    "def save_figure(fig, out_path, dpi=160, bbox_inches=\"tight\"):\n",
    "    \"\"\"Save `fig` to `out_path` by calling fig.savefig with the given dpi and bbox_inches.\"\"\"\n",
    "    save_options = {\"dpi\": dpi, \"bbox_inches\": bbox_inches}\n",
    "    fig.savefig(out_path, **save_options)\n",
    "\n",
    "def plot_and_save_target_sweep(\n",
    "    nT_values, out_prefix,\n",
    "    **kwargs  # passed to sweep_target_samples_vs_errors\n",
    "):\n",
    "    \"\"\"\n",
    "    Run the target-size sweep and persist its outputs next to `out_prefix`.\n",
    "\n",
    "    Writes <out_prefix>_target.csv, <out_prefix>_target.json,\n",
    "    <out_prefix>_target_typeI.png and <out_prefix>_target_typeII.png,\n",
    "    then returns the results dict produced by the sweep.\n",
    "    \"\"\"\n",
    "    fig_type1, fig_type2, results = sweep_target_samples_vs_errors(nT_values, **kwargs)\n",
    "    stem = f\"{out_prefix}_target\"\n",
    "    save_sweep_results_csv(nT_values, results, \"nT\", stem + \".csv\")\n",
    "    save_sweep_results_json(nT_values, results, \"nT\", stem + \".json\")\n",
    "    for fig, tag in ((fig_type1, \"typeI\"), (fig_type2, \"typeII\")):\n",
    "        save_figure(fig, f\"{stem}_{tag}.png\")\n",
    "    return results\n",
    "\n",
    "def plot_and_save_source_sweep(\n",
    "    nS_values, out_prefix,\n",
    "    **kwargs  # passed to sweep_source_samples_vs_errors\n",
    "):\n",
    "    \"\"\"\n",
    "    Run the source-size sweep and persist its outputs next to `out_prefix`.\n",
    "\n",
    "    Writes <out_prefix>_source.csv, <out_prefix>_source.json,\n",
    "    <out_prefix>_source_typeI.png and <out_prefix>_source_typeII.png,\n",
    "    then returns the results dict produced by the sweep.\n",
    "    \"\"\"\n",
    "    fig_type1, fig_type2, results = sweep_source_samples_vs_errors(nS_values, **kwargs)\n",
    "    stem = f\"{out_prefix}_source\"\n",
    "    save_sweep_results_csv(nS_values, results, \"nS\", stem + \".csv\")\n",
    "    save_sweep_results_json(nS_values, results, \"nS\", stem + \".json\")\n",
    "    for fig, tag in ((fig_type1, \"typeI\"), (fig_type2, \"typeII\")):\n",
    "        save_figure(fig, f\"{stem}_{tag}.png\")\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af268c43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source-size sweep on NASA: 100..1000 source samples per class in steps of 100,\n",
    "# with the target training size held at 40 per class.\n",
    "res = plot_and_save_source_sweep(\n",
    "    nS_values=list(range(100, 1001, 100)),\n",
    "    out_prefix=\"results_nasa\",\n",
    "    use_gaussian=False, use_nasa=True, nasa_d=6,\n",
    "    n0T=40, n1T=40, Ntest=1700,\n",
    "    model_kind=\"mlp\", alpha=0.10,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
