{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i_G25OyQqnOR"
      },
      "source": [
        "# Ablation of Calibration\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mGhlvcrUqt8H"
      },
      "source": [
        "## Model Definition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "WLtpF6y4pMqG"
      },
      "outputs": [],
      "source": [
        "# =========================\n",
        "# Calibration Ablation Top Cell (Anchor + Router, NO CAL step)\n",
        "# =========================\n",
        "import os, math, gc\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.metrics import mean_squared_error, r2_score\n",
        "from sklearn.ensemble import GradientBoostingRegressor\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "\n",
        "# ---------- Globals ----------\n",
        "# ln(2*pi): additive constant of the Gaussian log-density used by the NLL.\n",
        "LOG2PI = math.log(2.0 * math.pi)\n",
        "# Every tensor in this notebook is created on this device (GPU when torch can see one).\n",
        "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "torch.set_default_dtype(torch.float32)\n",
        "\n",
        "# ---------- Repro ----------\n",
        "def set_seed(seed: int = 1):\n",
        "    \"\"\"Seed the NumPy and torch RNGs; on CUDA machines also seed every GPU\n",
        "    and switch cuDNN to deterministic, non-benchmarking mode.\"\"\"\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    if not torch.cuda.is_available():\n",
        "        return\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "    cudnn = torch.backends.cudnn\n",
        "    cudnn.deterministic = True\n",
        "    cudnn.benchmark = False\n",
        "\n",
        "# ---------- Utils ----------\n",
        "def rmse_score(y_true, y_pred):\n",
        "    \"\"\"Root-mean-squared error between targets and predictions, as a Python float.\"\"\"\n",
        "    mse = mean_squared_error(y_true, y_pred)\n",
        "    return float(np.sqrt(mse))\n",
        "\n",
        "def zscore_fit(y_train):\n",
        "    \"\"\"Return ``(mean, std + 1e-8)`` of ``y_train`` as Python floats.\n",
        "\n",
        "    The small offset keeps the scale non-zero for a constant target.\n",
        "    \"\"\"\n",
        "    center = float(np.mean(y_train))\n",
        "    scale = float(np.std(y_train) + 1e-8)\n",
        "    return center, scale\n",
        "\n",
        "def _topk_mask(w, k=2):\n",
        "    \"\"\"Keep the k largest entries along the last dim, zero the others and\n",
        "    rescale the survivors so each row sums to (almost exactly) one.\"\"\"\n",
        "    top_idx = torch.topk(w, k, dim=-1).indices\n",
        "    keep = torch.zeros_like(w).scatter_(-1, top_idx, 1.0)\n",
        "    kept = w * keep\n",
        "    denom = kept.sum(dim=-1, keepdim=True) + 1e-12\n",
        "    return kept / denom\n",
        "\n",
        "def _topk_mask_smooth(w, k=2, eps=0.05):\n",
        "    \"\"\"Top-k gating with smoothing: renormalise the k largest entries, then mix\n",
        "    in a uniform ``eps / k`` share over those same k slots. Entries outside the\n",
        "    top-k stay exactly zero.\"\"\"\n",
        "    top_idx = torch.topk(w, k, dim=-1).indices\n",
        "    keep = torch.zeros_like(w).scatter_(-1, top_idx, 1.0)\n",
        "    kept = w * keep\n",
        "    kept = kept / (kept.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "    uniform_share = (eps / k) * keep\n",
        "    return (1.0 - eps) * kept + uniform_share\n",
        "\n",
        "# ---------- Dataset Loader ----------\n",
        "def load_dataset(name: str, verbose: bool = False):\n",
        "    \"\"\"Load a regression benchmark and return ``(X, y)`` as float64 arrays.\n",
        "\n",
        "    Every column is coerced to numeric and each row that contains NaN/inf\n",
        "    (including cells that failed the coercion) is dropped.\n",
        "\n",
        "    Args:\n",
        "        name: dataset key, matched case-insensitively after stripping\n",
        "            whitespace. ``housing``, ``concrete``, ``wine``, ``energy`` and\n",
        "            ``yacht`` are downloaded from the UCI archive; ``power``,\n",
        "            ``kin8nm``, ``msd``, ``protein`` and ``naval`` are read from ``data/``.\n",
        "        verbose: print a one-line shape summary (only the UCI sets and\n",
        "            ``power`` emit one).\n",
        "\n",
        "    Returns:\n",
        "        ``(X, y)``: 2-D feature array and 1-D target array.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: ``name`` is not a known dataset key.\n",
        "        FileNotFoundError: the local file for ``power`` is missing.\n",
        "    \"\"\"\n",
        "    def _clean_numeric(df: pd.DataFrame) -> pd.DataFrame:\n",
        "        # Non-numeric cells become NaN, then every row holding NaN/inf is removed.\n",
        "        df = df.apply(pd.to_numeric, errors=\"coerce\")\n",
        "        df = df.replace([np.inf, -np.inf], np.nan).dropna(axis=0)\n",
        "        return df\n",
        "\n",
        "    def _xy_last(df: pd.DataFrame):\n",
        "        # Target is the last column; every column before it is a feature.\n",
        "        feats = df.iloc[:, :-1].to_numpy(np.float64)\n",
        "        target = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        return feats, target\n",
        "\n",
        "    def _read_whitespace_table(src: str) -> pd.DataFrame:\n",
        "        # Header-less, whitespace-separated file; if the parser still returns a\n",
        "        # single column, split that column on whitespace by hand.\n",
        "        df = pd.read_csv(\n",
        "            src, header=None, sep=r\"\\s+\", engine=\"python\", comment=\"#\", skip_blank_lines=True\n",
        "        )\n",
        "        if df.shape[1] == 1:\n",
        "            df = df[0].astype(str).str.strip().str.split(r\"\\s+\", expand=True)\n",
        "        return df\n",
        "\n",
        "    name = name.lower().strip()\n",
        "\n",
        "    if name == \"housing\":\n",
        "        df = _clean_numeric(_read_whitespace_table(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data\"\n",
        "        ))\n",
        "        assert df.shape[1] >= 14\n",
        "        X, y = _xy_last(df)\n",
        "        if verbose: print(f\"[Housing] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"concrete\":\n",
        "        df = pd.read_excel(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/concrete/compressive/Concrete_Data.xls\"\n",
        "        )\n",
        "        X, y = _xy_last(_clean_numeric(df))\n",
        "        if verbose: print(f\"[Concrete] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"wine\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\",\n",
        "            delimiter=\";\"\n",
        "        )\n",
        "        X, y = _xy_last(_clean_numeric(df))\n",
        "        if verbose: print(f\"[WineRed] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"energy\":\n",
        "        df = pd.read_excel(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/00242/ENB2012_data.xlsx\"\n",
        "        )\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-2].to_numpy(np.float64)  # drop Y1,Y2\n",
        "        y = df.iloc[:, -1].to_numpy(np.float64)   # use Y2\n",
        "        if verbose: print(f\"[Energy] X.shape={X.shape} y.shape={y.shape} | y∈[{np.min(y):.3f},{np.max(y):.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"yacht\":\n",
        "        df = _clean_numeric(_read_whitespace_table(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/00243/yacht_hydrodynamics.data\"\n",
        "        ))\n",
        "        assert df.shape[1] >= 7\n",
        "        X, y = _xy_last(df)\n",
        "        if verbose: print(f\"[Yacht] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    # local datasets (optional)\n",
        "    if name == \"power\":\n",
        "        local = \"data/power.xlsx\"\n",
        "        if not os.path.exists(local):\n",
        "            # Report the path that was actually checked.\n",
        "            raise FileNotFoundError(f\"[power] need local file '{local}'\")\n",
        "        X, y = _xy_last(_clean_numeric(pd.read_excel(local)))\n",
        "        if verbose: print(f\"[Power(local)] X.shape={X.shape} y.shape={y.shape}\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"kin8nm\":\n",
        "        return _xy_last(_clean_numeric(pd.read_csv(\"data/kin8nm.csv\")))\n",
        "\n",
        "    if name == \"msd\":\n",
        "        df = pd.read_csv(\"data/YearPredictionMSD.txt\", header=None)\n",
        "        # Reverse the column order so the file's first column ends up last (the target slot).\n",
        "        df = df.iloc[:, ::-1]\n",
        "        return _xy_last(_clean_numeric(df))\n",
        "\n",
        "    if name == \"protein\":\n",
        "        df = _clean_numeric(pd.read_csv(\"data/protein.csv\", sep=None, engine=\"python\"))\n",
        "        # Here the target is the FIRST column.\n",
        "        y = df.iloc[:, 0].to_numpy(np.float64); X = df.iloc[:, 1:].to_numpy(np.float64)\n",
        "        return X, y\n",
        "\n",
        "    if name == \"naval\":\n",
        "        df = pd.read_csv(\"data/naval.txt\", sep=r\"\\s+\", header=None, engine=\"python\")\n",
        "        return _xy_last(_clean_numeric(df))\n",
        "\n",
        "    supported = (\"housing\", \"concrete\", \"wine\", \"energy\", \"yacht\",\n",
        "                 \"power\", \"kin8nm\", \"msd\", \"protein\", \"naval\")\n",
        "    raise ValueError(f\"Unknown dataset {name!r}; expected one of {supported}.\")\n",
        "\n",
        "# ---------- MoE building blocks (Anchor + Router) ----------\n",
        "class Projection(nn.Module):\n",
        "    \"\"\"Affine map from ``d`` input features to ``D`` dimensions\n",
        "    (Xavier-uniform weights, zero bias).\"\"\"\n",
        "\n",
        "    def __init__(self, d, D):\n",
        "        super().__init__()\n",
        "        self.w = nn.Linear(d, D, bias=True)\n",
        "        nn.init.xavier_uniform_(self.w.weight)\n",
        "        nn.init.zeros_(self.w.bias)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.w(x)\n",
        "\n",
        "class Window(nn.Module):\n",
        "    \"\"\"Gaussian-shaped gate per expert.\n",
        "\n",
        "    For each of the K centres ``c`` the output is\n",
        "    ``exp(-sum_d (z - c)^2 / (2 s^2)) + 1e-12`` where ``log s`` is a learned\n",
        "    per-dimension parameter clamped to ``[min_log_s, max_log_s]``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, K, D, min_log_s=-2.5, max_log_s=1.0):\n",
        "        super().__init__()\n",
        "        self.c = nn.Parameter(torch.randn(K, D))\n",
        "        self.log_s = nn.Parameter(torch.zeros(K, D))\n",
        "        self.min_log_s = min_log_s\n",
        "        self.max_log_s = max_log_s\n",
        "\n",
        "    def forward(self, z):\n",
        "        log_s = torch.clamp(self.log_s, min=self.min_log_s, max=self.max_log_s)\n",
        "        offset = z[:, None] - self.c          # broadcast every row of z against the K centres\n",
        "        scaled_sq = (offset**2) / (2 * torch.exp(log_s)**2)\n",
        "        return torch.exp(-scaled_sq.sum(dim=-1)) + 1e-12  # [B,K]\n",
        "\n",
        "class Router(nn.Module):\n",
        "    \"\"\"Softmax router: scaled dot products between a 64-dim query of ``z`` and\n",
        "    K learned keys, softened by the temperature ``tau``.\n",
        "\n",
        "    ``tau`` is a plain Python attribute (not a parameter or buffer).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, D, K):\n",
        "        super().__init__()\n",
        "        self.q = nn.Linear(D, 64)\n",
        "        self.k = nn.Parameter(torch.randn(K, 64))\n",
        "        self.tau = 3.0\n",
        "\n",
        "    def forward(self, z):\n",
        "        query = self.q(z)\n",
        "        scores = (query @ self.k.T) / math.sqrt(64)\n",
        "        return F.softmax(scores / self.tau, dim=-1)\n",
        "\n",
        "class ExpertMDN(nn.Module):\n",
        "    \"\"\"One mixture-density expert: a two-layer ReLU trunk with heads for the\n",
        "    component weights, (optionally) the component means, and the component\n",
        "    standard deviations clamped to ``[sigma_min, sigma_max]``.\n",
        "\n",
        "    With ``learn_mean=False`` no mean head is created and ``forward`` returns\n",
        "    ``None`` in the mean slot.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, d, h, nc, sigma_min=5e-2, sigma_max=1.0, learn_mean=True):\n",
        "        super().__init__()\n",
        "        trunk = [nn.Linear(d, h), nn.ReLU(), nn.Linear(h, h), nn.ReLU()]\n",
        "        self.net = nn.Sequential(*trunk)\n",
        "        self.logits = nn.Linear(h, nc)\n",
        "        self.learn_mean = learn_mean\n",
        "        if learn_mean:\n",
        "            self.means = nn.Linear(h, nc)\n",
        "        self.log_sc = nn.Linear(h, nc)\n",
        "        self.sigma_min = float(sigma_min)\n",
        "        self.sigma_max = float(sigma_max)\n",
        "        # Re-initialise every linear layer: Xavier-uniform weights, zero biases.\n",
        "        for layer in self.modules():\n",
        "            if not isinstance(layer, nn.Linear):\n",
        "                continue\n",
        "            nn.init.xavier_uniform_(layer.weight)\n",
        "            nn.init.zeros_(layer.bias)\n",
        "\n",
        "    def forward(self, x):\n",
        "        feats = self.net(x)\n",
        "        pi = F.softmax(self.logits(feats), dim=-1)\n",
        "        if self.learn_mean:\n",
        "            mu = self.means(feats)\n",
        "        else:\n",
        "            mu = None\n",
        "        sg = torch.exp(self.log_sc(feats)).clamp(self.sigma_min, self.sigma_max)\n",
        "        return pi, mu, sg\n",
        "\n",
        "class BLRMoE(nn.Module):\n",
        "    \"\"\"Anchor-MoE with Router; mean_mode in {'anchor','anchor+delta','free'}.\n",
        "\n",
        "    Inputs are projected to a D-dim space; the gating weight of each of the K\n",
        "    experts is the product of a Window value and a Router probability,\n",
        "    normalised and restricted to the top-k experts. Every expert is an\n",
        "    ExpertMDN with ``nc`` Gaussian components.\n",
        "\n",
        "    mean_mode:\n",
        "        'anchor'       - every component mean equals the supplied anchor\n",
        "                         (the experts are built without a mean head).\n",
        "        'anchor+delta' - component mean = anchor + expert output, with an L2\n",
        "                         penalty of weight ``delta_l2`` on the expert output.\n",
        "        'free'         - component mean = expert output only.\n",
        "    \"\"\"\n",
        "    def __init__(self, d, D, K, hid, nc,\n",
        "                 mean_mode='anchor+delta', delta_l2=3e-3,\n",
        "                 w_ent_warm=+1e-3, w_ent_cool=-2e-4,\n",
        "                 l2_win=1e-4, lb_coef=1e-3,\n",
        "                 sigma_min=5e-2, sigma_max=1.0,\n",
        "                 topk=2, smooth_eps=0.05):\n",
        "        super().__init__()\n",
        "        assert mean_mode in ['anchor','anchor+delta','free']\n",
        "        assert 1 <= topk <= K\n",
        "        self.mean_mode = mean_mode\n",
        "        self.delta_l2  = float(delta_l2)\n",
        "        self.proj   = Projection(d, D)\n",
        "        self.win    = Window(K, D)\n",
        "        self.router = Router(D, K)\n",
        "        # In pure 'anchor' mode the experts need no mean head.\n",
        "        learn_mean  = (mean_mode != 'anchor')\n",
        "        self.exps   = nn.ModuleList([ExpertMDN(d, hid, nc, sigma_min, sigma_max, learn_mean=learn_mean) for _ in range(K)])\n",
        "        # Weights of the router regulariser before / after the warm-up epochs (see nll).\n",
        "        self.w_ent_warm = w_ent_warm\n",
        "        self.w_ent_cool = w_ent_cool\n",
        "        self.l2_win     = l2_win\n",
        "        self.lb_coef    = lb_coef\n",
        "        self.topk       = topk\n",
        "        self.smooth_eps = smooth_eps\n",
        "        self.K = K; self.nc = nc\n",
        "\n",
        "    def _mixture_params(self, X, train=True):\n",
        "        \"\"\"Return ``(w, Pi, Mu, Sg, z)`` for a batch.\n",
        "\n",
        "        ``w`` is the [B,K] top-k gating matrix (smoothed variant when\n",
        "        ``train`` is true, hard variant otherwise); ``Pi``, ``Mu``, ``Sg`` are\n",
        "        [B,K,C] component weights, means and standard deviations; ``z`` is\n",
        "        the projected input.\n",
        "        \"\"\"\n",
        "        z = self.proj(X)\n",
        "        # Gate = window value * router probability, renormalised per row.\n",
        "        w = self.win(z) * self.router(z)\n",
        "        w = w / (w.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "        w = _topk_mask_smooth(w, k=self.topk, eps=self.smooth_eps) if train else _topk_mask(w, k=self.topk)\n",
        "\n",
        "        B, K, C = X.size(0), self.K, self.nc\n",
        "        # Placeholders for (row, expert) pairs that are not selected:\n",
        "        # uniform component weights, zero mean, unit sigma.\n",
        "        Pi = torch.full((B, K, C), 1.0/C, device=X.device)\n",
        "        Mu = torch.zeros(B, K, C, device=X.device)\n",
        "        Sg = torch.ones(B, K, C, device=X.device)\n",
        "\n",
        "        _, topi = torch.topk(w, self.topk, dim=-1)\n",
        "        # Only run experts that appear in at least one row's top-k.\n",
        "        uniq = torch.unique(topi)\n",
        "        for j in uniq.tolist():\n",
        "            # The expert is evaluated on the whole batch; `mask` then keeps its\n",
        "            # output only for the rows that selected expert j.\n",
        "            pi_j, mu_j, sg_j = self.exps[j](X)\n",
        "            mask = (topi == j).any(dim=1).float().unsqueeze(-1)\n",
        "            Pi[:, j, :] = pi_j * mask + Pi[:, j, :]*(1 - mask)\n",
        "            if mu_j is not None:\n",
        "                Mu[:, j, :] = mu_j * mask + Mu[:, j, :]*(1 - mask)\n",
        "            Sg[:, j, :] = sg_j * mask + Sg[:, j, :]*(1 - mask)\n",
        "        return w, Pi, Mu, Sg, z\n",
        "\n",
        "    def nll(self, X, y_z, mu_anchor_z=None, epoch=1, warmup_epochs=150):\n",
        "        \"\"\"Mean negative log-likelihood of ``y_z`` under the gated mixture.\n",
        "\n",
        "        ``y_z`` and ``mu_anchor_z`` are indexed as 1-D per-sample tensors;\n",
        "        ``mu_anchor_z`` is required unless ``mean_mode == 'free'``.\n",
        "        In training mode (``self.training``) the router-entropy, window-L2,\n",
        "        load-balancing and delta penalties are added to the returned value;\n",
        "        ``epoch``/``warmup_epochs`` only select which entropy weight is used.\n",
        "        \"\"\"\n",
        "        train_flag = self.training\n",
        "        w, Pi, Mu, Sg, z = self._mixture_params(X, train=train_flag)\n",
        "\n",
        "        if self.mean_mode == 'anchor':\n",
        "            assert mu_anchor_z is not None\n",
        "            Mu_eff = mu_anchor_z[:, None, None]\n",
        "            delta_pen = 0.0\n",
        "        elif self.mean_mode == 'anchor+delta':\n",
        "            assert mu_anchor_z is not None\n",
        "            Mu_eff = mu_anchor_z[:, None, None] + Mu\n",
        "            delta_pen = (Mu**2).mean() * self.delta_l2\n",
        "        else:\n",
        "            Mu_eff = Mu\n",
        "            delta_pen = 0.0\n",
        "\n",
        "        yv = y_z[:, None, None]\n",
        "        # Per-component Gaussian log-density.\n",
        "        logp = -0.5 * ((yv - Mu_eff)/Sg)**2 - torch.log(Sg) - 0.5*LOG2PI\n",
        "\n",
        "        w3   = w[:, :, None]\n",
        "        # Experts with zero gate get a huge negative log-weight so they drop out of the logsumexp.\n",
        "        logw = torch.where(w3 > 0, torch.log(w3 + 1e-12), torch.full_like(w3, -1e9))\n",
        "        logpi= torch.log(Pi + 1e-12)\n",
        "\n",
        "        # Log-density of the full mixture: sum over experts and components in log space.\n",
        "        logmix = torch.logsumexp(logw + logpi + logp, dim=(1,2))\n",
        "        nll = -logmix.mean()\n",
        "\n",
        "        if train_flag:\n",
        "            w_ent = self.w_ent_warm if epoch <= warmup_epochs else self.w_ent_cool\n",
        "            p = self.router(z)\n",
        "            # `ent` is sum p*log p, i.e. the NEGATIVE router entropy: a positive\n",
        "            # w_ent rewards spread-out routing, a negative one rewards peaked routing.\n",
        "            ent = (p * torch.log(p + 1e-12)).sum(dim=1).mean()\n",
        "            # L2 on the raw (unclamped) window log-scales.\n",
        "            l2w = (self.win.log_s**2).mean()\n",
        "            # Load balancing: squared deviation of the batch-mean routing probability from uniform.\n",
        "            rho = p.mean(dim=0)\n",
        "            lb_loss = ((rho - 1.0/p.size(1))**2).sum()\n",
        "            nll = nll + w_ent*ent + self.l2_win*l2w + self.lb_coef*lb_loss + delta_pen\n",
        "        return nll\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def predict_mean_var(self, X, mu_anchor_z=None):\n",
        "        \"\"\"Return the predictive mean and variance (both [B]) of the mixture.\n",
        "\n",
        "        Always uses the hard top-k gate. The variance is the mixture second\n",
        "        moment minus the squared mean, floored at 1e-9. Runs without gradients.\n",
        "        \"\"\"\n",
        "        w, Pi, Mu, Sg, _ = self._mixture_params(X, train=False)\n",
        "        if self.mean_mode == 'anchor':\n",
        "            assert mu_anchor_z is not None\n",
        "            Mu_eff = mu_anchor_z[:, None, None]\n",
        "        elif self.mean_mode == 'anchor+delta':\n",
        "            assert mu_anchor_z is not None\n",
        "            Mu_eff = mu_anchor_z[:, None, None] + Mu\n",
        "        else:\n",
        "            Mu_eff = Mu\n",
        "\n",
        "        # Per-expert first and second moments, then gate-weighted across experts.\n",
        "        mu_e  = (Pi * Mu_eff).sum(dim=2)              # [B,K]\n",
        "        m2_e  = (Pi * (Sg**2 + Mu_eff**2)).sum(dim=2) # [B,K]\n",
        "        mu_z  = (w * mu_e).sum(dim=1)                 # [B]\n",
        "        second= (w * m2_e).sum(dim=1)                 # [B]\n",
        "        var_z = torch.clamp(second - mu_z**2, min=1e-9)\n",
        "        return mu_z, var_z\n",
        "\n",
        "# ---------- One-split training (NO CALIBRATION) ----------\n",
        "def train_one_split_no_cal(\n",
        "    X_all, y_all, X_te, y_te,\n",
        "    standardize_x=True,\n",
        "    D=2, K=8, HID=128, NC=3,\n",
        "    LR=1e-3, EPOCHS=400,\n",
        "    MEAN_MODE='anchor+delta',\n",
        "    DELTA_L2=3e-3,\n",
        "    SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "    TOPK=2, SMOOTH_EPS=0.05,\n",
        "    seed=1, val_frac=0.2\n",
        "):\n",
        "    \"\"\"Train and evaluate Anchor-MoE on one split, with no post-hoc calibration.\n",
        "\n",
        "    Stage 1 holds out ``val_frac`` of ``(X_all, y_all)``, fits the GBDT anchor\n",
        "    (tree count picked by validation RMSE) and trains the MoE full-batch on the\n",
        "    training part, remembering the epoch with the lowest validation NLL.\n",
        "    Stage 2 refits the anchor on all of ``X_all`` with that tree count,\n",
        "    restores the best MoE weights together with the router temperature they\n",
        "    were validated at, trains ``best_ep`` more epochs on all of ``X_all`` and\n",
        "    scores the test set.\n",
        "\n",
        "    Args:\n",
        "        X_all, y_all: train+validation features / targets (NumPy arrays).\n",
        "        X_te, y_te: test features / targets.\n",
        "        standardize_x: z-score the augmented features (raw features plus the\n",
        "            anchor prediction) with the statistics of the current stage.\n",
        "        D, K, HID, NC, MEAN_MODE, DELTA_L2, SIGMA_MIN, SIGMA_MAX, TOPK,\n",
        "        SMOOTH_EPS: forwarded to ``BLRMoE``.\n",
        "        LR, EPOCHS: AdamW learning rate and number of stage-1 epochs.\n",
        "        seed: random state of the train/val split and of the GBDT models.\n",
        "        val_frac: fraction of ``X_all`` used for validation in stage 1.\n",
        "\n",
        "    Returns:\n",
        "        ``(rmse, test_nll, best_ep, r2)``: RMSE and R² on the original target\n",
        "        scale, NLL on targets z-scored with the train+val statistics.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: no epoch produced a validation NLL below the initial\n",
        "            sentinel (NaN loss, or ``EPOCHS < 1``), so nothing can be restored.\n",
        "    \"\"\"\n",
        "    X_tv, y_tv = X_all, y_all\n",
        "    X_tr, X_va, y_tr, y_va = train_test_split(\n",
        "        X_tv, y_tv, test_size=val_frac, random_state=seed\n",
        "    )\n",
        "\n",
        "    # ---- Anchor: pick the number of boosting rounds on the validation part ----\n",
        "    gbdt_big = GradientBoostingRegressor(\n",
        "        n_estimators=2000, learning_rate=0.05, max_depth=3,\n",
        "        subsample=1.0, random_state=seed\n",
        "    ).fit(X_tr, y_tr)\n",
        "\n",
        "    va_curve = [rmse_score(y_va, p) for p in gbdt_big.staged_predict(X_va)]\n",
        "    best_it  = int(np.argmin(va_curve)) + 1\n",
        "\n",
        "    # Stage-1 anchor (train part only) and stage-2 anchor (train+val).\n",
        "    gbdt_sub   = GradientBoostingRegressor(\n",
        "        n_estimators=best_it, learning_rate=0.05, max_depth=3,\n",
        "        subsample=1.0, random_state=seed\n",
        "    ).fit(X_tr, y_tr)\n",
        "\n",
        "    gbdt_final = GradientBoostingRegressor(\n",
        "        n_estimators=best_it, learning_rate=0.05, max_depth=3,\n",
        "        subsample=1.0, random_state=seed\n",
        "    ).fit(X_tv, y_tv)\n",
        "\n",
        "    my_tr, sy_tr = zscore_fit(y_tr)\n",
        "    my_tv, sy_tv = zscore_fit(y_tv)\n",
        "\n",
        "    # Anchor predictions in z-space; they are also appended as an extra input feature.\n",
        "    mu_tr_anchor = (gbdt_sub.predict(X_tr) - my_tr) / sy_tr\n",
        "    mu_va_anchor = (gbdt_sub.predict(X_va) - my_tr) / sy_tr\n",
        "    X_tr_aug = np.column_stack([X_tr, mu_tr_anchor])\n",
        "    X_va_aug = np.column_stack([X_va, mu_va_anchor])\n",
        "\n",
        "    if standardize_x:\n",
        "        mx_tr = X_tr_aug.mean(0, keepdims=True)\n",
        "        sx_tr = X_tr_aug.std(0, keepdims=True) + 1e-8\n",
        "    else:\n",
        "        mx_tr = np.zeros((1, X_tr_aug.shape[1])); sx_tr = np.ones((1, X_tr_aug.shape[1]))\n",
        "\n",
        "    def to_tensor(x, yz=None, mx=None, sx=None):\n",
        "        # Standardise features with the given statistics and move to DEVICE as float32.\n",
        "        Xt = torch.tensor(((x - mx)/sx).astype(np.float32), device=DEVICE)\n",
        "        if yz is None: return Xt\n",
        "        Yt = torch.tensor(yz.astype(np.float32), device=DEVICE)\n",
        "        return Xt, Yt\n",
        "\n",
        "    y_tr_z = (y_tr - my_tr) / sy_tr\n",
        "    y_va_z = (y_va - my_tr) / sy_tr\n",
        "\n",
        "    model = BLRMoE(\n",
        "        d=X_tr_aug.shape[1], D=D, K=K, hid=HID, nc=NC,\n",
        "        mean_mode=MEAN_MODE, delta_l2=DELTA_L2,\n",
        "        sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX,\n",
        "        topk=TOPK, smooth_eps=SMOOTH_EPS\n",
        "    ).to(DEVICE)\n",
        "    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=3e-4)\n",
        "\n",
        "    Xtr_t, ytr_t = to_tensor(X_tr_aug, y_tr_z, mx_tr, sx_tr)\n",
        "    Xva_t, yva_t = to_tensor(X_va_aug, y_va_z, mx_tr, sx_tr)\n",
        "    mu_tr_t = torch.tensor(mu_tr_anchor.astype(np.float32), device=DEVICE)\n",
        "    mu_va_t = torch.tensor(mu_va_anchor.astype(np.float32), device=DEVICE)\n",
        "\n",
        "    # ---- Stage 1: full-batch training with early-stopping bookkeeping ----\n",
        "    best_ep, best_vnll, best_state = 0, +1e9, None\n",
        "    best_tau = model.router.tau\n",
        "    for ep in range(1, EPOCHS+1):\n",
        "        model.train(); opt.zero_grad()\n",
        "        loss = model.nll(Xtr_t, ytr_t, mu_anchor_z=mu_tr_t, epoch=ep, warmup_epochs=150)\n",
        "        loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), 2.0); opt.step()\n",
        "        # Anneal the router temperature towards 1.0.\n",
        "        model.router.tau = max(model.router.tau * 0.995, 1.0)\n",
        "\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            vnll = float(model.nll(Xva_t, yva_t, mu_anchor_z=mu_va_t).cpu().item())\n",
        "        if vnll < best_vnll:\n",
        "            best_vnll = vnll; best_ep = ep\n",
        "            # tau lives outside state_dict, so it has to be checkpointed by hand.\n",
        "            best_tau = model.router.tau\n",
        "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
        "\n",
        "    if best_state is None:\n",
        "        raise RuntimeError(\n",
        "            \"validation NLL never improved (NaN loss or EPOCHS < 1); no checkpoint to restore\"\n",
        "        )\n",
        "\n",
        "    # ---- Stage 2: refit on train+val, then evaluate on the test set ----\n",
        "    mu_tv_anchor = (gbdt_final.predict(X_tv) - my_tv) / sy_tv\n",
        "    mu_te_anchor = (gbdt_final.predict(X_te) - my_tv) / sy_tv\n",
        "\n",
        "    X_tv_aug = np.column_stack([X_tv, mu_tv_anchor])\n",
        "    X_te_aug = np.column_stack([X_te, mu_te_anchor])\n",
        "\n",
        "    if standardize_x:\n",
        "        mx_tv = X_tv_aug.mean(0, keepdims=True)\n",
        "        sx_tv = X_tv_aug.std(0, keepdims=True) + 1e-8\n",
        "    else:\n",
        "        mx_tv = np.zeros((1, X_tv_aug.shape[1])); sx_tv = np.ones((1, X_tv_aug.shape[1]))\n",
        "\n",
        "    y_tv_z = (y_tv - my_tv) / sy_tv\n",
        "    y_te_z = (y_te - my_tv) / sy_tv\n",
        "\n",
        "    model.load_state_dict(best_state)\n",
        "    # Restore the temperature the checkpoint was validated with; without this the\n",
        "    # refit and the test evaluation would run at the end-of-training temperature.\n",
        "    model.router.tau = best_tau\n",
        "    model.train()\n",
        "    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=3e-4)\n",
        "\n",
        "    Xtv_t, ytv_t = to_tensor(X_tv_aug, y_tv_z, mx_tv, sx_tv)\n",
        "    Xte_t        = to_tensor(X_te_aug, None,   mx_tv, sx_tv)\n",
        "    mu_tv_t  = torch.tensor(mu_tv_anchor.astype(np.float32), device=DEVICE)\n",
        "    mu_te_t  = torch.tensor(mu_te_anchor.astype(np.float32), device=DEVICE)\n",
        "\n",
        "    for ep in range(1, best_ep+1):\n",
        "        opt.zero_grad()\n",
        "        loss = model.nll(Xtv_t, ytv_t, mu_anchor_z=mu_tv_t, epoch=ep, warmup_epochs=150)\n",
        "        loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), 2.0); opt.step()\n",
        "        model.router.tau = max(model.router.tau * 0.995, 1.0)\n",
        "\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        yte_t = torch.tensor(y_te_z.astype(np.float32), device=DEVICE)\n",
        "        test_nll = float(model.nll(Xte_t, yte_t, mu_anchor_z=mu_te_t).cpu().item())\n",
        "\n",
        "        mu_z_te, _ = model.predict_mean_var(Xte_t, mu_anchor_z=mu_te_t)\n",
        "        # Map the predictive mean back to the original target scale.\n",
        "        mu_te_orig = mu_z_te.cpu().numpy().astype(np.float64) * sy_tv + my_tv\n",
        "\n",
        "    rmse = rmse_score(y_te.astype(np.float64), mu_te_orig)\n",
        "    r2   = float(r2_score(y_te.astype(np.float64), mu_te_orig))\n",
        "    return rmse, test_nll, best_ep, r2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IK_uRpyDq32v"
      },
      "source": [
        "\n",
        "\n",
        "## Housing\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "EbclDbcCqfgS",
        "outputId": "e0501ff0-5768-459d-a424-5fc5c46286bc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Housing] X.shape=(506, 13) y.shape=(506,) | y∈[5.000,50.000]\n",
            "[01/20] MoE best_ep= 14  TestRMSE(orig)=2.5180  TestNLL(z)=0.2402  R²=0.8867\n",
            "[02/20] MoE best_ep= 14  TestRMSE(orig)=2.5395  TestNLL(z)=0.4070  R²=0.8973\n",
            "[03/20] MoE best_ep= 10  TestRMSE(orig)=2.5489  TestNLL(z)=0.2886  R²=0.9212\n",
            "[04/20] MoE best_ep= 10  TestRMSE(orig)=2.9835  TestNLL(z)=0.4587  R²=0.9087\n",
            "[05/20] MoE best_ep= 12  TestRMSE(orig)=3.0009  TestNLL(z)=0.6621  R²=0.9293\n",
            "[06/20] MoE best_ep= 11  TestRMSE(orig)=2.5941  TestNLL(z)=0.3232  R²=0.9172\n",
            "[07/20] MoE best_ep= 10  TestRMSE(orig)=1.9909  TestNLL(z)=0.4026  R²=0.8782\n",
            "[08/20] MoE best_ep= 11  TestRMSE(orig)=2.9305  TestNLL(z)=0.4421  R²=0.8788\n",
            "[09/20] MoE best_ep= 13  TestRMSE(orig)=3.0064  TestNLL(z)=1.1618  R²=0.8916\n",
            "[10/20] MoE best_ep=  9  TestRMSE(orig)=4.2192  TestNLL(z)=0.8531  R²=0.8068\n",
            "[11/20] MoE best_ep= 10  TestRMSE(orig)=2.9192  TestNLL(z)=0.7322  R²=0.9136\n",
            "[12/20] MoE best_ep= 10  TestRMSE(orig)=2.1877  TestNLL(z)=0.1980  R²=0.9301\n",
            "[13/20] MoE best_ep= 14  TestRMSE(orig)=2.4697  TestNLL(z)=0.5157  R²=0.9092\n",
            "[14/20] MoE best_ep= 12  TestRMSE(orig)=3.0828  TestNLL(z)=0.8499  R²=0.9106\n",
            "[15/20] MoE best_ep= 12  TestRMSE(orig)=2.8455  TestNLL(z)=0.4888  R²=0.8934\n",
            "[16/20] MoE best_ep=  8  TestRMSE(orig)=2.7269  TestNLL(z)=0.3832  R²=0.9237\n",
            "[17/20] MoE best_ep=  9  TestRMSE(orig)=2.8015  TestNLL(z)=0.2637  R²=0.8602\n",
            "[18/20] MoE best_ep= 10  TestRMSE(orig)=2.7915  TestNLL(z)=0.6940  R²=0.8930\n",
            "[19/20] MoE best_ep= 14  TestRMSE(orig)=2.3286  TestNLL(z)=0.6053  R²=0.9376\n",
            "[20/20] MoE best_ep= 11  TestRMSE(orig)=2.5028  TestNLL(z)=0.4085  R²=0.9417\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\n",
            "RMSE (orig) = 2.7494 ± 0.1013\n",
            "NLL  (z)    = 0.5189 ± 0.0548\n",
            " R²         = 0.9014 ± 0.0069\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing — Calibration Ablation (NO post-hoc CAL), leak-free ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"housing\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "N_SPLITS = 20      # number of random train/test partitions\n",
        "TRAIN_FRAC = 0.9   # share of rows given to train+validation in each partition\n",
        "n = X.shape[0]\n",
        "rng = np.random.RandomState(SEED)\n",
        "\n",
        "# Draw all partitions up front from one RNG stream so they do not depend on training.\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for _split in range(N_SPLITS):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[:n_train]\n",
        "    te = perm[n_train:]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{N_SPLITS}] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    # Release per-split models and cached GPU memory before the next split.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "# Summary: mean ± standard error over the splits\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "\n",
        "def se(a):\n",
        "    \"\"\"Standard error of the mean (sample std with ddof=1 over sqrt(len)).\"\"\"\n",
        "    return a.std(ddof=1)/np.sqrt(len(a))\n",
        "\n",
        "print(\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RPaRJvis8AY7"
      },
      "source": [
        "## Concrete"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YdoM5Rj54JbN",
        "outputId": "92ea22f3-6061-41cf-d088-4703f1f06025"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Concrete] X.shape=(1030, 8) y.shape=(1030,) | y∈[2.332,82.599]\n",
            "[01/20] MoE best_ep= 18  TestRMSE(orig)=4.7647  TestNLL(z)=0.4434  R²=0.9259\n",
            "[02/20] MoE best_ep= 17  TestRMSE(orig)=3.4393  TestNLL(z)=-0.0077  R²=0.9541\n",
            "[03/20] MoE best_ep= 18  TestRMSE(orig)=2.8907  TestNLL(z)=-0.1717  R²=0.9716\n",
            "[04/20] MoE best_ep= 15  TestRMSE(orig)=4.4722  TestNLL(z)=0.1325  R²=0.9272\n",
            "[05/20] MoE best_ep= 13  TestRMSE(orig)=4.6924  TestNLL(z)=0.2881  R²=0.9080\n",
            "[06/20] MoE best_ep= 15  TestRMSE(orig)=3.6997  TestNLL(z)=-0.0508  R²=0.9481\n",
            "[07/20] MoE best_ep= 16  TestRMSE(orig)=4.2595  TestNLL(z)=0.4259  R²=0.9384\n",
            "[08/20] MoE best_ep= 16  TestRMSE(orig)=4.8239  TestNLL(z)=0.0243  R²=0.9116\n",
            "[09/20] MoE best_ep= 16  TestRMSE(orig)=4.1273  TestNLL(z)=0.1597  R²=0.9415\n",
            "[10/20] MoE best_ep= 16  TestRMSE(orig)=4.7230  TestNLL(z)=0.3897  R²=0.9230\n",
            "[11/20] MoE best_ep= 15  TestRMSE(orig)=3.6446  TestNLL(z)=-0.0125  R²=0.9452\n",
            "[12/20] MoE best_ep= 12  TestRMSE(orig)=4.4414  TestNLL(z)=0.1453  R²=0.9263\n",
            "[13/20] MoE best_ep= 12  TestRMSE(orig)=3.6679  TestNLL(z)=0.0789  R²=0.9533\n",
            "[14/20] MoE best_ep= 16  TestRMSE(orig)=3.7083  TestNLL(z)=-0.0745  R²=0.9447\n",
            "[15/20] MoE best_ep= 16  TestRMSE(orig)=4.4655  TestNLL(z)=0.3172  R²=0.9030\n",
            "[16/20] MoE best_ep= 17  TestRMSE(orig)=4.8842  TestNLL(z)=0.2753  R²=0.8911\n",
            "[17/20] MoE best_ep= 17  TestRMSE(orig)=3.9608  TestNLL(z)=0.3610  R²=0.9389\n",
            "[18/20] MoE best_ep= 16  TestRMSE(orig)=4.2852  TestNLL(z)=-0.0207  R²=0.9284\n",
            "[19/20] MoE best_ep= 18  TestRMSE(orig)=4.7694  TestNLL(z)=0.8391  R²=0.9235\n",
            "[20/20] MoE best_ep= 16  TestRMSE(orig)=3.8921  TestNLL(z)=0.5292  R²=0.9369\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\n",
            "RMSE (orig) = 4.1806 ± 0.1231\n",
            "NLL  (z)    = 0.2036 ± 0.0558\n",
            " R²         = 0.9320 ± 0.0043\n"
          ]
        }
      ],
      "source": [
        "# === Run on Concrete — Calibration Ablation (NO post-hoc CAL), leak-free ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"concrete\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "N_SPLITS = 20      # number of random train/test partitions\n",
        "TRAIN_FRAC = 0.9   # share of rows given to train+validation in each partition\n",
        "n = X.shape[0]\n",
        "rng = np.random.RandomState(SEED)\n",
        "\n",
        "# Draw all partitions up front from one RNG stream so they do not depend on training.\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for _split in range(N_SPLITS):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[:n_train]\n",
        "    te = perm[n_train:]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{N_SPLITS}] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    # Release per-split models and cached GPU memory before the next split.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "# Summary: mean ± standard error over the splits\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "\n",
        "def se(a):\n",
        "    \"\"\"Standard error of the mean (sample std with ddof=1 over sqrt(len)).\"\"\"\n",
        "    return a.std(ddof=1)/np.sqrt(len(a))\n",
        "\n",
        "# The banner names the dataset actually evaluated in this cell.\n",
        "print(\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on Concrete ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J8mnDVZ58G7n"
      },
      "source": [
        "## Energy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "rwMNaohw8JnP",
        "outputId": "ba4800df-9899-40a4-849b-13c987994b25"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Energy] X.shape=(768, 8) y.shape=(768,) | y∈[10.900,48.030]\n",
            "[01/20] MoE best_ep= 61  TestRMSE(orig)=0.9426  TestNLL(z)=-0.9539  R²=0.9904\n",
            "[02/20] MoE best_ep= 96  TestRMSE(orig)=0.9832  TestNLL(z)=-0.9238  R²=0.9906\n",
            "[03/20] MoE best_ep=169  TestRMSE(orig)=0.8509  TestNLL(z)=-1.2703  R²=0.9918\n",
            "[04/20] MoE best_ep=331  TestRMSE(orig)=0.9713  TestNLL(z)=-0.9051  R²=0.9868\n",
            "[05/20] MoE best_ep=347  TestRMSE(orig)=1.1201  TestNLL(z)=-0.6470  R²=0.9867\n",
            "[06/20] MoE best_ep=120  TestRMSE(orig)=1.0670  TestNLL(z)=-1.0686  R²=0.9860\n",
            "[07/20] MoE best_ep=350  TestRMSE(orig)=1.0840  TestNLL(z)=-0.4402  R²=0.9891\n",
            "[08/20] MoE best_ep=336  TestRMSE(orig)=0.7901  TestNLL(z)=-1.2204  R²=0.9937\n",
            "[09/20] MoE best_ep= 33  TestRMSE(orig)=0.9056  TestNLL(z)=-0.9674  R²=0.9902\n",
            "[10/20] MoE best_ep=121  TestRMSE(orig)=0.9693  TestNLL(z)=-0.9573  R²=0.9894\n",
            "[11/20] MoE best_ep=381  TestRMSE(orig)=1.1105  TestNLL(z)=-0.6504  R²=0.9867\n",
            "[12/20] MoE best_ep= 41  TestRMSE(orig)=0.8842  TestNLL(z)=-1.2069  R²=0.9911\n",
            "[13/20] MoE best_ep=238  TestRMSE(orig)=1.2827  TestNLL(z)=-1.0035  R²=0.9832\n",
            "[14/20] MoE best_ep= 57  TestRMSE(orig)=1.1051  TestNLL(z)=-0.8121  R²=0.9859\n",
            "[15/20] MoE best_ep= 42  TestRMSE(orig)=1.1193  TestNLL(z)=-0.9558  R²=0.9861\n",
            "[16/20] MoE best_ep=293  TestRMSE(orig)=1.0002  TestNLL(z)=-0.9168  R²=0.9895\n",
            "[17/20] MoE best_ep= 33  TestRMSE(orig)=1.2355  TestNLL(z)=-0.8233  R²=0.9845\n",
            "[18/20] MoE best_ep= 53  TestRMSE(orig)=0.8531  TestNLL(z)=-1.2368  R²=0.9916\n",
            "[19/20] MoE best_ep= 42  TestRMSE(orig)=1.0761  TestNLL(z)=-1.0983  R²=0.9836\n",
            "[20/20] MoE best_ep= 97  TestRMSE(orig)=0.9167  TestNLL(z)=-1.1135  R²=0.9908\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\n",
            "RMSE (orig) = 1.0134 ± 0.0292\n",
            "NLL  (z)    = -0.9586 ± 0.0478\n",
            " R²         = 0.9884 ± 0.0007\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing — Calibration Ablation (NO post-hoc CAL), leak-free ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"energy\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "n = X.shape[0]\n",
        "rng = np.random.RandomState(SEED)\n",
        "\n",
        "# 20 × (90/10) 外层划分\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[: round(n * 0.9)]\n",
        "    te = perm[round(n * 0.9):]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "# 超参（与 Anchor-MoE 主设定一致）\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "# 汇总\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "\n",
        "print(\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "J9VIerQ0_QCJ",
        "outputId": "e7228814-26c1-48b5-afda-c08ded79bafa"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[01/20] MoE best_ep=  6  TestRMSE(orig)=0.1515  TestNLL(z)=0.9876  R²=0.6823\n",
            "[02/20] MoE best_ep=  7  TestRMSE(orig)=0.1475  TestNLL(z)=0.9942  R²=0.6921\n",
            "[03/20] MoE best_ep=  7  TestRMSE(orig)=0.1483  TestNLL(z)=0.9738  R²=0.6708\n",
            "[04/20] MoE best_ep=  7  TestRMSE(orig)=0.1422  TestNLL(z)=0.9972  R²=0.7104\n",
            "[05/20] MoE best_ep=  5  TestRMSE(orig)=0.1492  TestNLL(z)=0.9190  R²=0.6363\n",
            "[06/20] MoE best_ep=  7  TestRMSE(orig)=0.1474  TestNLL(z)=1.0024  R²=0.7023\n",
            "[07/20] MoE best_ep=  7  TestRMSE(orig)=0.1517  TestNLL(z)=1.0784  R²=0.6434\n",
            "[08/20] MoE best_ep=  6  TestRMSE(orig)=0.1421  TestNLL(z)=0.9413  R²=0.7363\n",
            "[09/20] MoE best_ep=  7  TestRMSE(orig)=0.1370  TestNLL(z)=0.8412  R²=0.7247\n",
            "[10/20] MoE best_ep=  6  TestRMSE(orig)=0.1500  TestNLL(z)=0.9915  R²=0.6860\n",
            "[11/20] MoE best_ep=  7  TestRMSE(orig)=0.1423  TestNLL(z)=0.9567  R²=0.7242\n",
            "[12/20] MoE best_ep=  7  TestRMSE(orig)=0.1473  TestNLL(z)=1.0254  R²=0.6846\n",
            "[13/20] MoE best_ep=  7  TestRMSE(orig)=0.1482  TestNLL(z)=0.9192  R²=0.7016\n",
            "[14/20] MoE best_ep=  9  TestRMSE(orig)=0.1346  TestNLL(z)=0.9408  R²=0.7358\n",
            "[15/20] MoE best_ep=  6  TestRMSE(orig)=0.1383  TestNLL(z)=0.8970  R²=0.7308\n",
            "[16/20] MoE best_ep=  6  TestRMSE(orig)=0.1477  TestNLL(z)=0.9704  R²=0.6812\n",
            "[17/20] MoE best_ep=  6  TestRMSE(orig)=0.1415  TestNLL(z)=0.8946  R²=0.7151\n",
            "[18/20] MoE best_ep=  6  TestRMSE(orig)=0.1491  TestNLL(z)=0.9989  R²=0.6691\n",
            "[19/20] MoE best_ep=  7  TestRMSE(orig)=0.1452  TestNLL(z)=1.0453  R²=0.7138\n",
            "[20/20] MoE best_ep=  6  TestRMSE(orig)=0.1519  TestNLL(z)=1.0397  R²=0.6753\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\n",
            "RMSE (orig) = 0.1457 ± 0.0011\n",
            "NLL  (z)    = 0.9707 ± 0.0130\n",
            " R²         = 0.6958 ± 0.0064\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing — Calibration Ablation (NO post-hoc CAL), leak-free ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"kin8nm\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "n = X.shape[0]\n",
        "rng = np.random.RandomState(SEED)\n",
        "\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[: round(n * 0.9)]\n",
        "    te = perm[round(n * 0.9):]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "# 汇总\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "\n",
        "print(\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LUZiODvWYPRN"
      },
      "source": [
        "## Naval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "nQg9HxzWYRPp",
        "outputId": "c1369c56-70c2-42a0-ef9f-a9f9c9abc00d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[01/20] MoE best_ep=279  TestRMSE(orig)=0.0008  TestNLL(z)=-0.9827  R²=0.9894\n",
            "[02/20] MoE best_ep=149  TestRMSE(orig)=0.0005  TestNLL(z)=-1.2423  R²=0.9950\n",
            "[03/20] MoE best_ep=304  TestRMSE(orig)=0.0006  TestNLL(z)=-0.9949  R²=0.9933\n",
            "[04/20] MoE best_ep=229  TestRMSE(orig)=0.0006  TestNLL(z)=-1.3553  R²=0.9943\n",
            "[05/20] MoE best_ep=279  TestRMSE(orig)=0.0006  TestNLL(z)=-1.0081  R²=0.9935\n",
            "[06/20] MoE best_ep=146  TestRMSE(orig)=0.0006  TestNLL(z)=-1.2442  R²=0.9939\n",
            "[07/20] MoE best_ep= 90  TestRMSE(orig)=0.0006  TestNLL(z)=-1.0729  R²=0.9930\n",
            "[08/20] MoE best_ep=103  TestRMSE(orig)=0.0005  TestNLL(z)=-1.2429  R²=0.9949\n",
            "[09/20] MoE best_ep=183  TestRMSE(orig)=0.0006  TestNLL(z)=-1.0587  R²=0.9934\n",
            "[10/20] MoE best_ep=104  TestRMSE(orig)=0.0006  TestNLL(z)=-1.1893  R²=0.9932\n",
            "[11/20] MoE best_ep=195  TestRMSE(orig)=0.0006  TestNLL(z)=-0.9942  R²=0.9939\n",
            "[12/20] MoE best_ep=372  TestRMSE(orig)=0.0005  TestNLL(z)=-1.1910  R²=0.9945\n",
            "[13/20] MoE best_ep=242  TestRMSE(orig)=0.0006  TestNLL(z)=-1.1072  R²=0.9927\n",
            "[14/20] MoE best_ep=373  TestRMSE(orig)=0.0005  TestNLL(z)=-1.2104  R²=0.9948\n",
            "[15/20] MoE best_ep=368  TestRMSE(orig)=0.0006  TestNLL(z)=-1.2374  R²=0.9945\n",
            "[16/20] MoE best_ep=138  TestRMSE(orig)=0.0006  TestNLL(z)=-1.0384  R²=0.9930\n",
            "[17/20] MoE best_ep= 63  TestRMSE(orig)=0.0005  TestNLL(z)=-1.2221  R²=0.9943\n",
            "[18/20] MoE best_ep=350  TestRMSE(orig)=0.0006  TestNLL(z)=-0.8690  R²=0.9924\n",
            "[19/20] MoE best_ep=259  TestRMSE(orig)=0.0006  TestNLL(z)=-1.1445  R²=0.9942\n",
            "[20/20] MoE best_ep= 75  TestRMSE(orig)=0.0006  TestNLL(z)=-1.0773  R²=0.9931\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\n",
            "RMSE (orig) = 0.0006 ± 0.0000\n",
            "NLL  (z)    = -1.1241 ± 0.0273\n",
            " R²         = 0.9936 ± 0.0003\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing — Calibration Ablation (NO post-hoc CAL), leak-free ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"naval\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "n = X.shape[0]\n",
        "rng = np.random.RandomState(SEED)\n",
        "\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[: round(n * 0.9)]\n",
        "    te = perm[round(n * 0.9):]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "# 汇总\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "\n",
        "print(\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qE_beSG4xziN"
      },
      "source": [
        "## Power"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IllKgkXox1IK",
        "outputId": "9ab8baca-8022-420b-a0fc-ea7df08773d7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Power(local)] X.shape=(9568, 4) y.shape=(9568,)\n",
            "[01/20] MoE best_ep=143  TestRMSE(orig)=3.1920  TestNLL(z)=-0.0919  R²=0.9633\n",
            "[02/20] MoE best_ep= 22  TestRMSE(orig)=3.2945  TestNLL(z)=-0.0917  R²=0.9634\n",
            "[03/20] MoE best_ep= 66  TestRMSE(orig)=2.9467  TestNLL(z)=-0.2605  R²=0.9713\n",
            "[04/20] MoE best_ep= 18  TestRMSE(orig)=2.9708  TestNLL(z)=-0.2593  R²=0.9677\n",
            "[05/20] MoE best_ep= 56  TestRMSE(orig)=3.6666  TestNLL(z)=-0.0787  R²=0.9553\n",
            "[06/20] MoE best_ep= 59  TestRMSE(orig)=3.4606  TestNLL(z)=-0.0649  R²=0.9591\n",
            "[07/20] MoE best_ep= 38  TestRMSE(orig)=3.0472  TestNLL(z)=-0.2186  R²=0.9674\n",
            "[08/20] MoE best_ep=102  TestRMSE(orig)=2.9472  TestNLL(z)=0.0552  R²=0.9706\n",
            "[09/20] MoE best_ep= 23  TestRMSE(orig)=3.3879  TestNLL(z)=-0.1962  R²=0.9615\n",
            "[10/20] MoE best_ep= 20  TestRMSE(orig)=3.2815  TestNLL(z)=-0.2119  R²=0.9623\n",
            "[11/20] MoE best_ep= 24  TestRMSE(orig)=2.9981  TestNLL(z)=-0.2023  R²=0.9684\n",
            "[12/20] MoE best_ep= 18  TestRMSE(orig)=3.2945  TestNLL(z)=-0.0847  R²=0.9636\n",
            "[13/20] MoE best_ep= 81  TestRMSE(orig)=2.9678  TestNLL(z)=-0.2796  R²=0.9704\n",
            "[14/20] MoE best_ep= 60  TestRMSE(orig)=2.9096  TestNLL(z)=-0.2087  R²=0.9706\n",
            "[15/20] MoE best_ep= 64  TestRMSE(orig)=3.2246  TestNLL(z)=-0.2910  R²=0.9628\n",
            "[16/20] MoE best_ep= 21  TestRMSE(orig)=2.7918  TestNLL(z)=-0.2918  R²=0.9736\n",
            "[17/20] MoE best_ep= 55  TestRMSE(orig)=2.9316  TestNLL(z)=-0.2773  R²=0.9705\n",
            "[18/20] MoE best_ep= 52  TestRMSE(orig)=3.4077  TestNLL(z)=-0.0709  R²=0.9610\n",
            "[19/20] MoE best_ep= 66  TestRMSE(orig)=3.3915  TestNLL(z)=-0.1979  R²=0.9613\n",
            "[20/20] MoE best_ep= 20  TestRMSE(orig)=2.9984  TestNLL(z)=-0.2327  R²=0.9684\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\n",
            "RMSE (orig) = 3.1555 ± 0.0525\n",
            "NLL  (z)    = -0.1778 ± 0.0217\n",
            " R²         = 0.9656 ± 0.0011\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing — Calibration Ablation (NO post-hoc CAL), leak-free ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"power\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "n = X.shape[0]\n",
        "rng = np.random.RandomState(SEED)\n",
        "\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[: round(n * 0.9)]\n",
        "    te = perm[round(n * 0.9):]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "\n",
        "print(\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SzXxINBF-Lug"
      },
      "source": [
        "## Protein"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Le-I547Bx_OF",
        "outputId": "86ce5a84-33db-4027-88c2-b17537166210"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[01/20] MoE best_ep=116  TestRMSE(orig)=4.4435  TestNLL(z)=1.0201  R²=0.4890\n",
            "[02/20] MoE best_ep=  6  TestRMSE(orig)=4.3157  TestNLL(z)=1.0940  R²=0.5116\n",
            "[03/20] MoE best_ep= 24  TestRMSE(orig)=4.2766  TestNLL(z)=0.9442  R²=0.4970\n",
            "[04/20] MoE best_ep= 73  TestRMSE(orig)=4.3144  TestNLL(z)=0.7369  R²=0.4867\n",
            "[05/20] MoE best_ep= 84  TestRMSE(orig)=4.4873  TestNLL(z)=0.9044  R²=0.4712\n",
            "[06/20] MoE best_ep= 88  TestRMSE(orig)=4.6044  TestNLL(z)=1.0003  R²=0.4451\n",
            "[07/20] MoE best_ep= 69  TestRMSE(orig)=4.2610  TestNLL(z)=0.8132  R²=0.5151\n",
            "[08/20] MoE best_ep= 74  TestRMSE(orig)=4.3224  TestNLL(z)=0.8632  R²=0.5149\n",
            "[09/20] MoE best_ep=  5  TestRMSE(orig)=4.3314  TestNLL(z)=1.1174  R²=0.5218\n",
            "[10/20] MoE best_ep=  5  TestRMSE(orig)=4.5072  TestNLL(z)=1.0794  R²=0.4602\n",
            "[11/20] MoE best_ep=  7  TestRMSE(orig)=4.3232  TestNLL(z)=1.0591  R²=0.5087\n",
            "[12/20] MoE best_ep=169  TestRMSE(orig)=4.2366  TestNLL(z)=0.7541  R²=0.5140\n",
            "[13/20] MoE best_ep= 78  TestRMSE(orig)=4.3122  TestNLL(z)=0.7623  R²=0.5193\n",
            "[14/20] MoE best_ep= 82  TestRMSE(orig)=4.3469  TestNLL(z)=0.8497  R²=0.4849\n",
            "[15/20] MoE best_ep= 76  TestRMSE(orig)=4.5454  TestNLL(z)=0.8636  R²=0.4445\n",
            "[16/20] MoE best_ep= 66  TestRMSE(orig)=4.4661  TestNLL(z)=0.7292  R²=0.4729\n",
            "[17/20] MoE best_ep= 85  TestRMSE(orig)=4.3022  TestNLL(z)=0.7536  R²=0.4815\n",
            "[18/20] MoE best_ep=106  TestRMSE(orig)=4.4230  TestNLL(z)=0.7599  R²=0.4753\n",
            "[19/20] MoE best_ep=  5  TestRMSE(orig)=4.2514  TestNLL(z)=0.9862  R²=0.5248\n",
            "[20/20] MoE best_ep= 88  TestRMSE(orig)=4.3888  TestNLL(z)=0.8182  R²=0.5026\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Protein (10k downsample) ==\n",
            "RMSE (orig) = 4.3730 ± 0.0234\n",
            "NLL  (z)    = 0.8954 ± 0.0296\n",
            " R²         = 0.4920 ± 0.0056\n"
          ]
        }
      ],
      "source": [
        "# === Run on Protein — Calibration Ablation (NO post-hoc CAL), leak-free (downsample to 10k) ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"protein\", verbose=True)   # requires local 'protein.csv'\n",
        "\n",
        "SEED = 1\n",
        "rng  = np.random.RandomState(SEED)\n",
        "\n",
        "# Downsample to 10,000 before any splitting (leak-free)\n",
        "MAX_N = 10_000\n",
        "if X.shape[0] > MAX_N:\n",
        "    idx = rng.choice(X.shape[0], MAX_N, replace=False)\n",
        "    X = X[idx].copy()\n",
        "    y = y[idx].copy()\n",
        "\n",
        "n = X.shape[0]\n",
        "\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[: round(n * 0.9)]\n",
        "    te = perm[round(n * 0.9):]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC  = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "\n",
        "print(\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on Protein (10k downsample) ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6pOBAjivLZ-9"
      },
      "source": [
        "## Wine"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "q-m0WcaiLc1i",
        "outputId": "6021e014-1c3e-4ca8-e285-752b6d7cadcf"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[WineRed] X.shape=(1599, 11) y.shape=(1599,) | y∈[3.000,8.000]\n",
            "[01/20] MoE best_ep=  1  TestRMSE(orig)=0.5817  TestNLL(z)=1.0947  R²=0.5390\n",
            "[02/20] MoE best_ep=  3  TestRMSE(orig)=0.5605  TestNLL(z)=1.1032  R²=0.4922\n",
            "[03/20] MoE best_ep=  5  TestRMSE(orig)=0.6973  TestNLL(z)=1.4297  R²=0.4524\n",
            "[04/20] MoE best_ep=  3  TestRMSE(orig)=0.6043  TestNLL(z)=1.1490  R²=0.4244\n",
            "[05/20] MoE best_ep=  3  TestRMSE(orig)=0.5705  TestNLL(z)=1.1133  R²=0.4163\n",
            "[06/20] MoE best_ep=  4  TestRMSE(orig)=0.6169  TestNLL(z)=1.1464  R²=0.4065\n",
            "[07/20] MoE best_ep=  7  TestRMSE(orig)=0.6311  TestNLL(z)=1.3918  R²=0.4794\n",
            "[08/20] MoE best_ep=  2  TestRMSE(orig)=0.6536  TestNLL(z)=1.1885  R²=0.4416\n",
            "[09/20] MoE best_ep=  5  TestRMSE(orig)=0.5806  TestNLL(z)=1.1549  R²=0.4523\n",
            "[10/20] MoE best_ep=  9  TestRMSE(orig)=0.6134  TestNLL(z)=1.3013  R²=0.5039\n",
            "[11/20] MoE best_ep=  7  TestRMSE(orig)=0.6394  TestNLL(z)=1.5736  R²=0.4249\n",
            "[12/20] MoE best_ep=  8  TestRMSE(orig)=0.6377  TestNLL(z)=1.1673  R²=0.2898\n",
            "[13/20] MoE best_ep=  2  TestRMSE(orig)=0.6668  TestNLL(z)=1.2035  R²=0.3023\n",
            "[14/20] MoE best_ep=  4  TestRMSE(orig)=0.6107  TestNLL(z)=1.1110  R²=0.2929\n",
            "[15/20] MoE best_ep=  3  TestRMSE(orig)=0.5696  TestNLL(z)=1.1032  R²=0.5506\n",
            "[16/20] MoE best_ep=  3  TestRMSE(orig)=0.5351  TestNLL(z)=0.9547  R²=0.4858\n",
            "[17/20] MoE best_ep=  6  TestRMSE(orig)=0.5888  TestNLL(z)=1.3310  R²=0.5124\n",
            "[18/20] MoE best_ep=  5  TestRMSE(orig)=0.5865  TestNLL(z)=1.1444  R²=0.5090\n",
            "[19/20] MoE best_ep= 30  TestRMSE(orig)=0.6059  TestNLL(z)=1.2044  R²=0.4235\n",
            "[20/20] MoE best_ep= 26  TestRMSE(orig)=0.5977  TestNLL(z)=1.3220  R²=0.4792\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\n",
            "RMSE (orig) = 0.6074 ± 0.0087\n",
            "NLL  (z)    = 1.2094 ± 0.0319\n",
            " R²         = 0.4439 ± 0.0170\n"
          ]
        }
      ],
      "source": [
        "# === Run on Wine — Calibration Ablation (NO post-hoc CAL), leak-free ===\n",
        "# Calls train_one_split_no_cal on repeated random 90/10 train/test splits of one\n",
        "# dataset and reports mean ± standard error of RMSE, NLL and R² across splits.\n",
        "DATASET = \"wine\"    # key passed to load_dataset; also drives the summary label\n",
        "SEED = 1\n",
        "N_SPLITS = 20\n",
        "TRAIN_FRAC = 0.9\n",
        "\n",
        "set_seed(SEED)\n",
        "X, y = load_dataset(DATASET, verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "rng = np.random.RandomState(SEED)\n",
        "\n",
        "# Draw every split up front from a dedicated RandomState, before any training\n",
        "# runs, so the partitions are fixed by SEED alone.\n",
        "splits = []\n",
        "for i in range(N_SPLITS):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[:n_train]\n",
        "    te = perm[n_train:]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "# Hyperparameters, forwarded unchanged to train_one_split_no_cal.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    # Pass copies so the callee cannot modify X / y in place across splits;\n",
        "    # each split gets its own training seed (SEED + itr).\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{len(splits):02d}] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    # Release memory held by the finished split before starting the next one.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "# Summary\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "\n",
        "# Standard error of the mean: sample std (ddof=1) divided by sqrt(number of splits).\n",
        "def se(a):\n",
        "    return a.std(ddof=1) / np.sqrt(len(a))\n",
        "\n",
        "print(f\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on {DATASET.title()} ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "szcalz3CNIZ6",
        "outputId": "2ef62516-73a9-49f7-a933-5f2b46ee0163"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Yacht] X.shape=(308, 6) y.shape=(308,) | y∈[0.010,62.420]\n",
            "[01/20] MoE best_ep=394  TestRMSE(orig)=0.6753  TestNLL(z)=-1.6954  R²=0.9981\n",
            "[02/20] MoE best_ep=298  TestRMSE(orig)=0.4221  TestNLL(z)=-1.9212  R²=0.9991\n",
            "[03/20] MoE best_ep=380  TestRMSE(orig)=0.3478  TestNLL(z)=-1.9736  R²=0.9990\n",
            "[04/20] MoE best_ep=188  TestRMSE(orig)=0.5106  TestNLL(z)=-1.8341  R²=0.9992\n",
            "[05/20] MoE best_ep=303  TestRMSE(orig)=0.3559  TestNLL(z)=-1.9596  R²=0.9996\n",
            "[06/20] MoE best_ep=199  TestRMSE(orig)=0.4025  TestNLL(z)=-1.9257  R²=0.9988\n",
            "[07/20] MoE best_ep=244  TestRMSE(orig)=0.2699  TestNLL(z)=-2.0159  R²=0.9982\n",
            "[08/20] MoE best_ep= 62  TestRMSE(orig)=0.8250  TestNLL(z)=-1.5515  R²=0.9968\n",
            "[09/20] MoE best_ep=188  TestRMSE(orig)=0.4697  TestNLL(z)=-1.8853  R²=0.9987\n",
            "[10/20] MoE best_ep=289  TestRMSE(orig)=0.3375  TestNLL(z)=-1.9937  R²=0.9987\n",
            "[11/20] MoE best_ep=267  TestRMSE(orig)=0.6499  TestNLL(z)=-1.7082  R²=0.9986\n",
            "[12/20] MoE best_ep=178  TestRMSE(orig)=0.7129  TestNLL(z)=-1.7633  R²=0.9967\n",
            "[13/20] MoE best_ep=190  TestRMSE(orig)=0.5820  TestNLL(z)=-1.7790  R²=0.9987\n",
            "[14/20] MoE best_ep=179  TestRMSE(orig)=0.5094  TestNLL(z)=-1.8742  R²=0.9990\n",
            "[15/20] MoE best_ep=103  TestRMSE(orig)=0.9293  TestNLL(z)=-1.6204  R²=0.9965\n",
            "[16/20] MoE best_ep=317  TestRMSE(orig)=0.3636  TestNLL(z)=-1.9592  R²=0.9993\n",
            "[17/20] MoE best_ep= 58  TestRMSE(orig)=0.3866  TestNLL(z)=-1.8947  R²=0.9991\n",
            "[18/20] MoE best_ep= 78  TestRMSE(orig)=0.5729  TestNLL(z)=-1.7719  R²=0.9984\n",
            "[19/20] MoE best_ep=397  TestRMSE(orig)=0.4313  TestNLL(z)=-1.9039  R²=0.9992\n",
            "[20/20] MoE best_ep=341  TestRMSE(orig)=0.7196  TestNLL(z)=-1.6631  R²=0.9985\n",
            "\n",
            "== Anchor-MoE (Calibration Ablation: NO CAL) on Housing ==\n",
            "RMSE (orig) = 0.5237 ± 0.0402\n",
            "NLL  (z)    = -1.8347 ± 0.0299\n",
            " R²         = 0.9985 ± 0.0002\n"
          ]
        }
      ],
      "source": [
        "# === Run on Yacht — Calibration Ablation (NO post-hoc CAL), leak-free ===\n",
        "# Calls train_one_split_no_cal on repeated random 90/10 train/test splits of one\n",
        "# dataset and reports mean ± standard error of RMSE, NLL and R² across splits.\n",
        "DATASET = \"yacht\"   # key passed to load_dataset; also drives the summary label\n",
        "SEED = 1\n",
        "N_SPLITS = 20\n",
        "TRAIN_FRAC = 0.9\n",
        "\n",
        "set_seed(SEED)\n",
        "X, y = load_dataset(DATASET, verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "rng = np.random.RandomState(SEED)\n",
        "\n",
        "# Draw every split up front from a dedicated RandomState, before any training\n",
        "# runs, so the partitions are fixed by SEED alone.\n",
        "splits = []\n",
        "for i in range(N_SPLITS):\n",
        "    perm = rng.permutation(n)\n",
        "    tr = perm[:n_train]\n",
        "    te = perm[n_train:]\n",
        "    splits.append((tr, te))\n",
        "\n",
        "# Hyperparameters, forwarded unchanged to train_one_split_no_cal.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS     = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS     = 2, 0.05\n",
        "\n",
        "rmses, nlls, r2s = [], [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    # Pass copies so the callee cannot modify X / y in place across splits;\n",
        "    # each split gets its own training seed (SEED + itr).\n",
        "    rmse, nll, best_ep, r2 = train_one_split_no_cal(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{len(splits):02d}] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll); r2s.append(r2)\n",
        "    # Release memory held by the finished split before starting the next one.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "# Summary\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls  = np.asarray(nlls,  dtype=np.float64)\n",
        "r2s   = np.asarray(r2s,   dtype=np.float64)\n",
        "\n",
        "# Standard error of the mean: sample std (ddof=1) divided by sqrt(number of splits).\n",
        "def se(a):\n",
        "    return a.std(ddof=1) / np.sqrt(len(a))\n",
        "\n",
        "print(f\"\\n== Anchor-MoE (Calibration Ablation: NO CAL) on {DATASET.title()} ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n",
        "print(f\" R²         = {r2s.mean():.4f} ± {se(r2s):.4f}\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "gpuType": "V6E1",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
