{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hdwtPrgi1tUc"
      },
      "source": [
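        "# Router Ablation\n",
        "\n",
        "Mixture-of-experts variant with the router removed (`BLRMoE_NoRouter`): expert weights come only from learnable Gaussian windows over a linear projection of the inputs, followed by top-k selection. Component means are anchored on a GBDT prediction, and test predictions are linearly calibrated on a held-out split. Each dataset section reports RMSE and NLL (mean ± standard error) over 20 random 90/10 train/test splits."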
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PZrWFBqm1y3Z"
      },
      "source": [
        "\n",
        "## Model Definition\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4EPCdXFC2lFL"
      },
      "outputs": [],
      "source": [
        "# =========================\n",
        "# Router Ablation Top Cell (Fixed)\n",
        "# =========================\n",
        "import os, math, gc\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.metrics import mean_squared_error, r2_score\n",
        "from sklearn.ensemble import GradientBoostingRegressor\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "\n",
        "# ---------- Globals ----------\n",
        "LOG2PI  = math.log(2*math.pi)\n",
        "DEVICE  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "torch.set_default_dtype(torch.float32)\n",
        "\n",
        "# ---------- Repro ----------\n",
        "def set_seed(seed: int = 1):\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "        torch.backends.cudnn.deterministic = True\n",
        "        torch.backends.cudnn.benchmark = False\n",
        "\n",
        "# ---------- Utils ----------\n",
        "def rmse_score(y_true, y_pred):\n",
        "    \"\"\"Root-mean-squared error between ``y_true`` and ``y_pred`` as a Python float.\"\"\"\n",
        "    mse = mean_squared_error(y_true, y_pred)\n",
        "    return float(math.sqrt(mse))\n",
        "\n",
        "def zscore_fit(y_train):\n",
        "    my = float(np.mean(y_train))\n",
        "    sy = float(np.std(y_train) + 1e-8)\n",
        "    return my, sy\n",
        "\n",
        "def _topk_mask(w, k=2):\n",
        "    _, topi = torch.topk(w, k, dim=-1)\n",
        "    mask = torch.zeros_like(w).scatter_(-1, topi, 1.0)\n",
        "    w2 = w * mask\n",
        "    return w2 / (w2.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "\n",
        "def _topk_mask_smooth(w, k=2, eps=0.05):\n",
        "    _, topi = torch.topk(w, k, dim=-1)\n",
        "    mask = torch.zeros_like(w).scatter_(-1, topi, 1.0)\n",
        "    w_top = w * mask\n",
        "    w_top = w_top / (w_top.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "    return (1.0 - eps) * w_top + (eps / k) * mask\n",
        "\n",
        "# ---------- Dataset Loader (fixed for Energy & Naval) ----------\n",
        "def load_dataset(name: str, verbose: bool = False):\n",
        "   \n",
        "    def _clean_numeric(df: pd.DataFrame) -> pd.DataFrame:\n",
        "        df = df.apply(pd.to_numeric, errors=\"coerce\")\n",
        "        df = df.replace([np.inf, -np.inf], np.nan).dropna(axis=0)\n",
        "        return df\n",
        "\n",
        "    name = name.lower().strip()\n",
        "\n",
        "    if name == \"yacht\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/00243/yacht_hydrodynamics.data\",\n",
        "            header=None, sep=r\"\\s+\", engine=\"python\", comment=\"#\", skip_blank_lines=True\n",
        "        )\n",
        "        if df.shape[1] == 1:\n",
        "            df = df[0].astype(str).str.strip().str.split(r\"\\s+\", expand=True)\n",
        "        df = _clean_numeric(df); assert df.shape[1] >= 7\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        if verbose: print(f\"[Yacht] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"housing\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data\",\n",
        "            header=None, sep=r\"\\s+\", engine=\"python\", comment=\"#\", skip_blank_lines=True\n",
        "        )\n",
        "        if df.shape[1] == 1:\n",
        "            df = df[0].astype(str).str.strip().str.split(r\"\\s+\", expand=True)\n",
        "        df = _clean_numeric(df); assert df.shape[1] >= 14\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        if verbose: print(f\"[Housing] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"concrete\":\n",
        "        df = pd.read_excel(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/concrete/compressive/Concrete_Data.xls\"\n",
        "        )\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        if verbose: print(f\"[Concrete] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"wine\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\",\n",
        "            delimiter=\";\"\n",
        "        )\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        if verbose: print(f\"[WineRed] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"energy\":\n",
        "        df = pd.read_excel(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/00242/ENB2012_data.xlsx\"\n",
        "        )\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-2].to_numpy(np.float64)  # 丢掉 Y1、Y2\n",
        "        y = df.iloc[:, -1].to_numpy(np.float64)   # 取 Y2\n",
        "        if verbose: print(f\"[Energy] X.shape={X.shape} y.shape={y.shape} | y∈[{np.min(y):.3f},{np.max(y):.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"protein\":\n",
        "        local = \"data/protein.csv\"\n",
        "        if not os.path.exists(local): raise FileNotFoundError(\"[protein]  'protein.csv'\")\n",
        "        df = pd.read_csv(local, sep=None, engine=\"python\")\n",
        "        df = _clean_numeric(df)\n",
        "        y = df.iloc[:, 0].to_numpy(np.float64)\n",
        "        X = df.iloc[:, 1:].to_numpy(np.float64)\n",
        "        if verbose:\n",
        "            y_min, y_max = float(np.nanmin(y)), float(np.nanmax(y))\n",
        "            print(f\"[Protein(local)] X.shape={X.shape} y.shape={y.shape} | y∈[{y_min:.3f},{y_max:.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"naval\":\n",
        "        local = \"data/naval.txt\"\n",
        "        if not os.path.exists(local): raise FileNotFoundError(\"[naval]  'naval.txt'\")\n",
        "        df = pd.read_csv(local, sep=r\"\\s+\", header=None, engine=\"python\")\n",
        "        df = _clean_numeric(df)\n",
        "        # UCI Naval 常见有 2 个目标；这里默认取“最后一列”为 y\n",
        "        # 如需 Y1，可改成 y = df.iloc[:, -2]\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64)\n",
        "        y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        if verbose: print(f\"[Naval(local)] X.shape={X.shape} y.shape={y.shape} | y∈[{np.min(y):.6f},{np.max(y):.6f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"kin8nm\":\n",
        "        local = \"data/kin8nm.csv\"\n",
        "        if not os.path.exists(local): raise FileNotFoundError(\"[kin8nm]  'kin8nm.csv'\")\n",
        "        df = pd.read_csv(local); df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        if verbose: print(f\"[Kin8nm] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"msd\":\n",
        "        local = \"data/YearPredictionMSD.txt\"\n",
        "        if not os.path.exists(local): raise FileNotFoundError(\"[msd]  'YearPredictionMSD.txt'\")\n",
        "        df = pd.read_csv(local, header=None); df = df.iloc[:, ::-1]\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        if verbose: print(f\"[MSD] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"power\":\n",
        "        local = \"data/power.xlsx\"  \n",
        "        if not os.path.exists(local): raise FileNotFoundError(\"[power]  'power.xlsx'\")\n",
        "        df = pd.read_excel(local); df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        if verbose: print(f\"[Power(local)] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    raise ValueError(\"Unknown dataset.\")\n",
        "\n",
        "# ---------- MoE blocks (Router ablation) ----------\n",
        "class Projection(nn.Module):\n",
        "    def __init__(self, d, D):\n",
        "        super().__init__()\n",
        "        self.w = nn.Linear(d, D, bias=True)\n",
        "        nn.init.xavier_uniform_(self.w.weight); nn.init.zeros_(self.w.bias)\n",
        "    def forward(self, x): return self.w(x)\n",
        "\n",
        "class Window(nn.Module):\n",
        "    def __init__(self, K, D, min_log_s=-2.5, max_log_s=1.0):\n",
        "        super().__init__()\n",
        "        self.c         = nn.Parameter(torch.randn(K, D))\n",
        "        self.log_s     = nn.Parameter(torch.zeros(K, D))\n",
        "        self.min_log_s = min_log_s; self.max_log_s = max_log_s\n",
        "    def forward(self, z):\n",
        "        log_s = torch.clamp(self.log_s, min=self.min_log_s, max=self.max_log_s)\n",
        "        diff2 = ((z[:, None] - self.c)**2) / (2 * torch.exp(log_s)**2)\n",
        "        return torch.exp(-diff2.sum(dim=-1)) + 1e-12  # [B,K]\n",
        "\n",
        "class ExpertMDN(nn.Module):\n",
        "    def __init__(self, d, h, nc, sigma_min=5e-2, sigma_max=1.0, learn_mean=True):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(d, h), nn.ReLU(),\n",
        "            nn.Linear(h, h), nn.ReLU()\n",
        "        )\n",
        "        self.logits = nn.Linear(h, nc)\n",
        "        self.learn_mean = learn_mean\n",
        "        if learn_mean:\n",
        "            self.means  = nn.Linear(h, nc)\n",
        "        self.log_sc = nn.Linear(h, nc)\n",
        "        self.sigma_min = float(sigma_min)\n",
        "        self.sigma_max = float(sigma_max)\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, nn.Linear):\n",
        "                nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)\n",
        "    def forward(self, x):\n",
        "        h  = self.net(x)\n",
        "        pi = F.softmax(self.logits(h), dim=-1)\n",
        "        mu = self.means(h) if self.learn_mean else None\n",
        "        sg = torch.exp(self.log_sc(h)).clamp(self.sigma_min, self.sigma_max)\n",
        "        return pi, mu, sg\n",
        "\n",
        "class BLRMoE_NoRouter(nn.Module):\n",
        "    \"\"\"Mixture of ``ExpertMDN`` experts gated only by Gaussian windows (router ablation).\n",
        "\n",
        "    ``X`` is linearly projected to ``D`` dims (``Projection``); ``Window`` scores the\n",
        "    ``K`` experts in that space; the scores are normalised and restricted to the\n",
        "    ``topk`` best experts (smoothed by ``smooth_eps`` in training mode, hard in\n",
        "    eval mode). Each expert is an ``nc``-component Gaussian mixture over the\n",
        "    z-scored target.\n",
        "\n",
        "    ``mean_mode`` selects the component means:\n",
        "      * ``'anchor'``       - every mean is the external anchor ``mu_anchor_z``;\n",
        "      * ``'anchor+delta'`` - anchor plus a learned offset, L2-penalised by ``delta_l2``;\n",
        "      * ``'free'``         - learned means only, no anchor.\n",
        "    ``l2_win`` weights an L2 penalty on the window log-widths (training mode only).\n",
        "    \"\"\"\n",
        "    def __init__(self, d, D, K, hid, nc,\n",
        "                 mean_mode='anchor+delta', delta_l2=3e-3,\n",
        "                 l2_win=1e-4,\n",
        "                 sigma_min=5e-2, sigma_max=1.0,\n",
        "                 topk=2, smooth_eps=0.05):\n",
        "        super().__init__()\n",
        "        assert mean_mode in ['anchor','anchor+delta','free']\n",
        "        assert 1 <= topk <= K\n",
        "        self.mean_mode = mean_mode\n",
        "        self.delta_l2  = float(delta_l2)\n",
        "        self.proj   = Projection(d, D)\n",
        "        self.win    = Window(K, D)\n",
        "        # In 'anchor' mode the means come entirely from the anchor, so experts get no mean head.\n",
        "        learn_mean = (mean_mode != 'anchor')\n",
        "        self.exps   = nn.ModuleList([ExpertMDN(d, hid, nc, sigma_min, sigma_max, learn_mean=learn_mean) for _ in range(K)])\n",
        "        self.l2_win     = l2_win\n",
        "        self.topk       = topk\n",
        "        self.smooth_eps = smooth_eps\n",
        "        self.K = K; self.nc = nc\n",
        "\n",
        "    def _gating(self, z, train=True):\n",
        "        \"\"\"Gate weights ``[B, K]``: normalised window scores restricted to the top-k\n",
        "        experts (smoothed top-k when ``train`` is True, hard renormalised top-k otherwise).\"\"\"\n",
        "        w = self.win(z)  # [B,K] >= 0\n",
        "        w = w / (w.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "        w = _topk_mask_smooth(w, k=self.topk, eps=self.smooth_eps) if train else _topk_mask(w, k=self.topk)\n",
        "        return w\n",
        "\n",
        "    def _mixture_params(self, X, train=True):\n",
        "        \"\"\"Return ``(w, Pi, Mu, Sg, z)``.\n",
        "\n",
        "        ``w`` is ``[B, K]`` gate weights; ``Pi``/``Mu``/``Sg`` are ``[B, K, nc]``\n",
        "        per-expert mixing weights, means (or mean offsets) and std-devs; ``z`` is\n",
        "        the projection of ``X``. Only experts that appear in some row's top-k are\n",
        "        evaluated; rows where expert j is not selected keep placeholder values\n",
        "        (uniform Pi, Mu = 0, Sg = 1).\n",
        "        \"\"\"\n",
        "        z = self.proj(X)\n",
        "        w = self._gating(z, train=train)\n",
        "\n",
        "        B, K, C = X.size(0), self.K, self.nc\n",
        "        Pi = torch.full((B, K, C), 1.0/C, device=X.device)\n",
        "        Mu = torch.zeros(B, K, C, device=X.device)\n",
        "        Sg = torch.ones(B, K, C, device=X.device)\n",
        "\n",
        "        # Recover each row's selected experts from w; entries outside the gating's top-k are 0.\n",
        "        _, topi = torch.topk(w, self.topk, dim=-1)\n",
        "        uniq = torch.unique(topi)\n",
        "        for j in uniq.tolist():\n",
        "            pi_j, mu_j, sg_j = self.exps[j](X)\n",
        "            # [B,1] indicator: 1 where expert j is among the row's selected experts.\n",
        "            mask = (topi == j).any(dim=1).float().unsqueeze(-1)\n",
        "            Pi[:, j, :] = pi_j * mask + Pi[:, j, :]*(1 - mask)\n",
        "            if mu_j is not None:\n",
        "                Mu[:, j, :] = mu_j * mask + Mu[:, j, :]*(1 - mask)\n",
        "            Sg[:, j, :] = sg_j * mask + Sg[:, j, :]*(1 - mask)\n",
        "        return w, Pi, Mu, Sg, z\n",
        "\n",
        "    def nll(self, X, y_z, mu_anchor_z=None, epoch=1, warmup_epochs=150):\n",
        "        \"\"\"Mean negative log-likelihood of ``y_z`` under the gated mixture.\n",
        "\n",
        "        ``-mean_b log sum_{k,c} w[b,k] * Pi[b,k,c] * N(y_z[b]; Mu_eff[b,k,c], Sg[b,k,c]^2)``.\n",
        "        Gating follows ``self.training`` (smoothed vs. hard top-k). In training mode the\n",
        "        window log-width penalty and, for 'anchor+delta', the delta penalty are added.\n",
        "        ``epoch`` and ``warmup_epochs`` are accepted but not used.\n",
        "        \"\"\"\n",
        "        w, Pi, Mu, Sg, z = self._mixture_params(X, train=self.training)\n",
        "\n",
        "        if self.mean_mode == 'anchor':\n",
        "            assert mu_anchor_z is not None\n",
        "            # [B,1,1]: the anchor is broadcast over experts and components.\n",
        "            Mu_eff = mu_anchor_z[:, None, None]\n",
        "            delta_pen = 0.0\n",
        "        elif self.mean_mode == 'anchor+delta':\n",
        "            assert mu_anchor_z is not None\n",
        "            Mu_eff = mu_anchor_z[:, None, None] + Mu\n",
        "            # Mean over all [B,K,C] offsets, including the zero placeholders.\n",
        "            delta_pen = (Mu**2).mean() * self.delta_l2\n",
        "        else:\n",
        "            Mu_eff = Mu\n",
        "            delta_pen = 0.0\n",
        "\n",
        "        yv = y_z[:, None, None]\n",
        "        # Per-component Gaussian log-density, [B,K,C].\n",
        "        logp = -0.5 * ((yv - Mu_eff)/Sg)**2 - torch.log(Sg) - 0.5*LOG2PI\n",
        "\n",
        "        w3   = w[:, :, None]\n",
        "        # Zero gate weight -> log-weight -1e9, which removes that expert from the logsumexp.\n",
        "        logw = torch.where(w3 > 0, torch.log(w3 + 1e-12), torch.full_like(w3, -1e9))\n",
        "        logpi= torch.log(Pi + 1e-12)\n",
        "\n",
        "        logmix = torch.logsumexp(logw + logpi + logp, dim=(1,2))\n",
        "        nll = -logmix.mean()\n",
        "\n",
        "        # Regularisers only enter the training objective; eval returns the pure NLL.\n",
        "        if self.training:\n",
        "            l2w = (self.win.log_s**2).mean()\n",
        "            nll = nll + self.l2_win*l2w + delta_pen\n",
        "        return nll\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def predict_mean_var(self, X, mu_anchor_z=None):\n",
        "        \"\"\"Predictive mean and variance, each ``[B]``, in z-space, using hard top-k gating.\n",
        "\n",
        "        The variance is the mixture second moment minus the squared mean, clamped at 1e-9.\n",
        "        \"\"\"\n",
        "        w, Pi, Mu, Sg, _ = self._mixture_params(X, train=False)\n",
        "        if self.mean_mode == 'anchor':\n",
        "            assert mu_anchor_z is not None\n",
        "            Mu_eff = mu_anchor_z[:, None, None]\n",
        "        elif self.mean_mode == 'anchor+delta':\n",
        "            assert mu_anchor_z is not None\n",
        "            Mu_eff = mu_anchor_z[:, None, None] + Mu\n",
        "        else:\n",
        "            Mu_eff = Mu\n",
        "\n",
        "        mu_e  = (Pi * Mu_eff).sum(dim=2)              # [B,K]\n",
        "        m2_e  = (Pi * (Sg**2 + Mu_eff**2)).sum(dim=2) # [B,K]\n",
        "        mu_z  = (w * mu_e).sum(dim=1)                 # [B]\n",
        "        second= (w * m2_e).sum(dim=1)                 # [B]\n",
        "        var_z = torch.clamp(second - mu_z**2, min=1e-9)\n",
        "        return mu_z, var_z\n",
        "\n",
        "# ---------- One-split training (No-leak; with Anchor + CAL) ----------\n",
        "def train_one_split_router_off(\n",
        "    X_all, y_all, X_te, y_te,\n",
        "    standardize_x=True,\n",
        "    D=2, K=8, HID=128, NC=3,\n",
        "    LR=1e-3, EPOCHS=400,\n",
        "    MEAN_MODE='anchor+delta',\n",
        "    DELTA_L2=3e-3,\n",
        "    SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "    TOPK=2, SMOOTH_EPS=0.05,\n",
        "    seed=1\n",
        "):\n",
        "    \"\"\"Train and evaluate the router-ablated MoE on one outer train/test split.\n",
        "\n",
        "    ``X_all``/``y_all`` is the outer training portion; ``X_te``/``y_te`` is the\n",
        "    outer test portion. Protocol:\n",
        "\n",
        "    1. Split the training portion into TV / CAL (87.5% / 12.5%), then TV into\n",
        "       TR / VA (80% / 20%).\n",
        "    2. Anchor: fit a 2000-tree GBDT on TR and pick the iteration count with the\n",
        "       lowest VA RMSE. Refit with that count on TR (``gbdt_sub``) and on TV\n",
        "       (``gbdt_final``). The z-scored anchor is appended to the features and\n",
        "       also passed to the model as ``mu_anchor_z``.\n",
        "    3. Stage 1: full-batch AdamW on TR for ``EPOCHS`` epochs, keeping the\n",
        "       weights with the lowest VA NLL (reached at epoch ``best_ep``).\n",
        "    4. Stage 2: starting from those weights, run ``best_ep`` more epochs on TV,\n",
        "       with features and targets re-standardised on TV.\n",
        "    5. Fit ``y ~ a * mu + b`` by least squares on CAL and apply it to the test\n",
        "       mean predictions.\n",
        "\n",
        "    ``seed`` fixes the CAL/TR/VA splits and the GBDTs. Network initialisation\n",
        "    draws from the global torch RNG, which is not reseeded here.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    (rmse, test_nll, best_ep, r2)\n",
        "        RMSE and R^2 of the calibrated test predictions in original units.\n",
        "        ``test_nll`` is the mean test NLL in TV z-space and is not calibrated.\n",
        "\n",
        "    Raises\n",
        "    ------\n",
        "    RuntimeError\n",
        "        If no epoch produced a validation NLL below +inf (e.g. NaN every epoch,\n",
        "        or ``EPOCHS < 1``), so there is no checkpoint to restore.\n",
        "    \"\"\"\n",
        "    # ---- TV / CAL ----\n",
        "    X_tv, X_cal, y_tv, y_cal = train_test_split(X_all, y_all, test_size=0.125, random_state=seed)\n",
        "    # ---- TV -> TR / VA ----\n",
        "    X_tr, X_va, y_tr, y_va   = train_test_split(X_tv,  y_tv,  test_size=0.2,   random_state=seed)\n",
        "\n",
        "    # 1) Anchor: select the GBDT iteration count best_it on TR/VA.\n",
        "    gbdt_full_tr = GradientBoostingRegressor(\n",
        "        n_estimators=2000, learning_rate=0.05, max_depth=3, subsample=1.0, random_state=seed\n",
        "    ).fit(X_tr, y_tr)\n",
        "    val_rmse = [rmse_score(y_va, p) for p in gbdt_full_tr.staged_predict(X_va)]\n",
        "    best_it  = int(np.argmin(val_rmse)) + 1\n",
        "\n",
        "    # Refit with best_it trees: on TR (stage-1 anchors) and on TV (stage-2 / CAL / test anchors).\n",
        "    gbdt_sub   = GradientBoostingRegressor(n_estimators=best_it, learning_rate=0.05,\n",
        "                                           max_depth=3, subsample=1.0, random_state=seed).fit(X_tr, y_tr)\n",
        "    gbdt_final = GradientBoostingRegressor(n_estimators=best_it, learning_rate=0.05,\n",
        "                                           max_depth=3, subsample=1.0, random_state=seed).fit(X_tv, y_tv)\n",
        "\n",
        "    # Target z-scoring statistics for each stage.\n",
        "    my_tr, sy_tr = zscore_fit(y_tr)\n",
        "    my_tv, sy_tv = zscore_fit(y_tv)\n",
        "\n",
        "    mu_tr_anchor = (gbdt_sub.predict(X_tr) - my_tr) / sy_tr\n",
        "    mu_va_anchor = (gbdt_sub.predict(X_va) - my_tr) / sy_tr\n",
        "\n",
        "    # The z-scored anchor is also fed to the network as an extra input feature.\n",
        "    X_tr_aug = np.column_stack([X_tr, mu_tr_anchor])\n",
        "    X_va_aug = np.column_stack([X_va, mu_va_anchor])\n",
        "\n",
        "    # Stage-1 feature standardisation uses TR statistics only.\n",
        "    if standardize_x:\n",
        "        mx_tr = X_tr_aug.mean(0, keepdims=True); sx_tr = X_tr_aug.std(0, keepdims=True) + 1e-8\n",
        "    else:\n",
        "        mx_tr = np.zeros((1, X_tr_aug.shape[1])); sx_tr = np.ones((1, X_tr_aug.shape[1]))\n",
        "\n",
        "    def to_tensor(x, yz=None, mx=None, sx=None):\n",
        "        # Standardise x with (mx, sx); return float32 tensor(s) on DEVICE.\n",
        "        Xt = torch.tensor(((x - mx)/sx).astype(np.float32), device=DEVICE)\n",
        "        if yz is None: return Xt\n",
        "        Yt = torch.tensor(yz.astype(np.float32), device=DEVICE)\n",
        "        return Xt, Yt\n",
        "\n",
        "    y_tr_z = (y_tr - my_tr) / sy_tr\n",
        "    y_va_z = (y_va - my_tr) / sy_tr\n",
        "\n",
        "    model = BLRMoE_NoRouter(\n",
        "        d=X_tr_aug.shape[1], D=D, K=K, hid=HID, nc=NC,\n",
        "        mean_mode=MEAN_MODE, delta_l2=DELTA_L2,\n",
        "        sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX,\n",
        "        topk=TOPK, smooth_eps=SMOOTH_EPS\n",
        "    ).to(DEVICE)\n",
        "    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=3e-4)\n",
        "\n",
        "    Xtr_t, ytr_t = to_tensor(X_tr_aug, y_tr_z, mx_tr, sx_tr)\n",
        "    Xva_t, yva_t = to_tensor(X_va_aug, y_va_z, mx_tr, sx_tr)\n",
        "    mu_tr_t = torch.tensor(mu_tr_anchor.astype(np.float32), device=DEVICE)\n",
        "    mu_va_t = torch.tensor(mu_va_anchor.astype(np.float32), device=DEVICE)\n",
        "\n",
        "    # 2) Stage 1: full-batch training on TR, checkpointing the best VA NLL.\n",
        "    # Start from +inf so any finite VA NLL is accepted; NaN never compares lower.\n",
        "    best_ep, best_vnll, best_state = 0, float('inf'), None\n",
        "    for ep in range(1, EPOCHS+1):\n",
        "        model.train(); opt.zero_grad()\n",
        "        loss = model.nll(Xtr_t, ytr_t, mu_anchor_z=mu_tr_t, epoch=ep, warmup_epochs=150)\n",
        "        loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), 2.0); opt.step()\n",
        "\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            vnll = float(model.nll(Xva_t, yva_t, mu_anchor_z=mu_va_t).cpu().item())\n",
        "        if vnll < best_vnll:\n",
        "            best_vnll = vnll; best_ep = ep\n",
        "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
        "\n",
        "    if best_state is None:\n",
        "        # Fail loudly here rather than inside load_state_dict(None) below.\n",
        "        raise RuntimeError(\n",
        "            f'no finite validation NLL in {EPOCHS} epoch(s); nothing to restore'\n",
        "        )\n",
        "\n",
        "    # 3) Stage 2: anchors and standardisation re-fitted on TV.\n",
        "    mu_tv_anchor = (gbdt_final.predict(X_tv) - my_tv) / sy_tv\n",
        "    mu_te_anchor = (gbdt_final.predict(X_te) - my_tv) / sy_tv\n",
        "    mu_cal_anchor= (gbdt_final.predict(X_cal)- my_tv) / sy_tv\n",
        "\n",
        "    X_tv_aug  = np.column_stack([X_tv,  mu_tv_anchor])\n",
        "    X_te_aug  = np.column_stack([X_te,  mu_te_anchor])\n",
        "    X_cal_aug = np.column_stack([X_cal, mu_cal_anchor])\n",
        "\n",
        "    if standardize_x:\n",
        "        mx_tv = X_tv_aug.mean(0, keepdims=True); sx_tv = X_tv_aug.std(0, keepdims=True) + 1e-8\n",
        "    else:\n",
        "        mx_tv = np.zeros((1, X_tv_aug.shape[1])); sx_tv = np.ones((1, X_tv_aug.shape[1]))\n",
        "\n",
        "    y_tv_z = (y_tv - my_tv) / sy_tv\n",
        "    y_te_z = (y_te - my_tv) / sy_tv\n",
        "\n",
        "    # NOTE(review): stage 2 warm-starts from the stage-1 checkpoint instead of a fresh\n",
        "    # init, so the network gets best_ep TR epochs plus best_ep TV epochs, and its\n",
        "    # inputs switch from TR to TV standardisation. Confirm this is intended.\n",
        "    model.load_state_dict(best_state)\n",
        "    model.train()\n",
        "    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=3e-4)\n",
        "\n",
        "    Xtv_t, ytv_t = to_tensor(X_tv_aug, y_tv_z, mx_tv, sx_tv)\n",
        "    Xte_t        = to_tensor(X_te_aug, None,   mx_tv, sx_tv)\n",
        "    Xcal_t       = to_tensor(X_cal_aug, None,  mx_tv, sx_tv)\n",
        "    mu_tv_t  = torch.tensor(mu_tv_anchor.astype(np.float32),  device=DEVICE)\n",
        "    mu_te_t  = torch.tensor(mu_te_anchor.astype(np.float32),  device=DEVICE)\n",
        "    mu_cal_t = torch.tensor(mu_cal_anchor.astype(np.float32), device=DEVICE)\n",
        "\n",
        "    for ep in range(1, best_ep+1):\n",
        "        opt.zero_grad()\n",
        "        loss = model.nll(Xtv_t, ytv_t, mu_anchor_z=mu_tv_t, epoch=ep, warmup_epochs=150)\n",
        "        loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), 2.0); opt.step()\n",
        "\n",
        "    # 4) Evaluate: test NLL in TV z-space; mean predictions mapped back to original units.\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        yte_t = torch.tensor(y_te_z.astype(np.float32), device=DEVICE)\n",
        "        test_nll = float(model.nll(Xte_t, yte_t, mu_anchor_z=mu_te_t).cpu().item())\n",
        "\n",
        "        mu_z_cal, _ = model.predict_mean_var(Xcal_t, mu_anchor_z=mu_cal_t)\n",
        "        mu_cal_orig = mu_z_cal.cpu().numpy().astype(np.float64) * sy_tv + my_tv\n",
        "\n",
        "        mu_z_te, _  = model.predict_mean_var(Xte_t,  mu_anchor_z=mu_te_t)\n",
        "        mu_te_orig  = mu_z_te.cpu().numpy().astype(np.float64) * sy_tv + my_tv\n",
        "\n",
        "    # 5) Linear recalibration y ~ a * mu + b, least-squares fit on CAL.\n",
        "    A = np.vstack([mu_cal_orig, np.ones_like(mu_cal_orig)]).T\n",
        "    a, b = np.linalg.lstsq(A, y_cal.astype(np.float64), rcond=None)[0]\n",
        "    mu_te_cal  = a * mu_te_orig + b\n",
        "\n",
        "    rmse = rmse_score(y_te.astype(np.float64), mu_te_cal)\n",
        "    r2   = float(r2_score(y_te.astype(np.float64), mu_te_cal))\n",
        "    return rmse, test_nll, best_ep, r2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OXJD1qxG2ki7"
      },
      "source": [
        "\n",
        "## Housing\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZekuiPlJ2lC4",
        "outputId": "bf09f83b-2e4a-4aff-92f1-afb74380db6f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Housing] X.shape=(506, 13) y.shape=(506,) | y∈[5.000,50.000]\n",
            "[01/20] best_ep=  8  RMSE=2.6053  NLL(z)=0.3143  R²=0.8787\n",
            "[02/20] best_ep= 16  RMSE=2.7610  NLL(z)=0.4358  R²=0.8786\n",
            "[03/20] best_ep= 14  RMSE=3.4645  NLL(z)=0.7230  R²=0.8544\n",
            "[04/20] best_ep=  9  RMSE=2.4260  NLL(z)=0.3063  R²=0.9396\n",
            "[05/20] best_ep= 14  RMSE=3.2127  NLL(z)=0.9847  R²=0.9189\n",
            "[06/20] best_ep= 11  RMSE=2.6928  NLL(z)=0.4585  R²=0.9108\n",
            "[07/20] best_ep=  9  RMSE=1.8283  NLL(z)=0.0781  R²=0.8972\n",
            "[08/20] best_ep=  7  RMSE=3.5190  NLL(z)=0.4713  R²=0.8252\n",
            "[09/20] best_ep=  9  RMSE=3.3076  NLL(z)=0.7731  R²=0.8687\n",
            "[10/20] best_ep=  8  RMSE=4.3314  NLL(z)=0.6275  R²=0.7964\n",
            "[11/20] best_ep=  9  RMSE=3.1307  NLL(z)=0.8264  R²=0.9006\n",
            "[12/20] best_ep= 11  RMSE=3.2477  NLL(z)=0.1629  R²=0.8459\n",
            "[13/20] best_ep= 14  RMSE=2.2793  NLL(z)=0.7567  R²=0.9227\n",
            "[14/20] best_ep=  9  RMSE=2.2978  NLL(z)=0.4888  R²=0.9504\n",
            "[15/20] best_ep= 10  RMSE=2.7671  NLL(z)=0.5348  R²=0.8992\n",
            "[16/20] best_ep=  9  RMSE=3.1925  NLL(z)=0.4066  R²=0.8955\n",
            "[17/20] best_ep= 10  RMSE=2.8554  NLL(z)=0.4542  R²=0.8548\n",
            "[18/20] best_ep= 11  RMSE=2.9956  NLL(z)=0.6046  R²=0.8768\n",
            "[19/20] best_ep=  9  RMSE=2.2968  NLL(z)=0.2571  R²=0.9393\n",
            "[20/20] best_ep=  8  RMSE=2.4659  NLL(z)=0.5232  R²=0.9434\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\n",
            "RMSE = 2.8839 ± 0.1269\n",
            "NLL  = 0.5094 ± 0.0513\n"
          ]
        }
      ],
      "source": [
        "# Housing: 20 random 90/10 outer train/test splits drawn from a fixed RNG.\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"housing\", verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "rng = np.random.RandomState(1)\n",
        "n_train = round(n*0.9)\n",
        "splits = [(p[:n_train], p[n_train:]) for p in (rng.permutation(n) for _ in range(20))]\n",
        "\n",
        "# Hyper-parameters shared by every split; only the seed varies.\n",
        "HPARAMS = dict(\n",
        "    D=2, K=8, HID=128, NC=3,\n",
        "    LR=1e-3, EPOCHS=400,\n",
        "    SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "    TOPK=2, SMOOTH_EPS=0.05,\n",
        ")\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(), y[te_idx].copy(),\n",
        "        seed=1+i, **HPARAMS\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses += [rmse]\n",
        "    nlls += [nll]\n",
        "\n",
        "def se(a):\n",
        "    \"\"\"Standard error of the mean: sample std (ddof=1) divided by sqrt(len).\"\"\"\n",
        "    return a.std(ddof=1)/np.sqrt(len(a))\n",
        "\n",
        "rmses, nlls = np.asarray(rmses), np.asarray(nlls)\n",
        "print(\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hmXkXQBV4Rme"
      },
      "source": [
        "## Concrete\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "M4vIFiC11LIc",
        "outputId": "2e7a6061-3266-49f1-fac6-0b680c8a7712"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Concrete] X.shape=(1030, 8) y.shape=(1030,) | y∈[2.332,82.599]\n",
            "[01/20] best_ep= 14  RMSE=4.7123  NLL(z)=0.2781  R²=0.9275\n",
            "[02/20] best_ep= 17  RMSE=4.2180  NLL(z)=0.3678  R²=0.9310\n",
            "[03/20] best_ep= 13  RMSE=3.7558  NLL(z)=-0.1359  R²=0.9521\n",
            "[04/20] best_ep= 14  RMSE=4.6757  NLL(z)=0.1788  R²=0.9204\n",
            "[05/20] best_ep= 16  RMSE=5.1519  NLL(z)=0.5474  R²=0.8892\n",
            "[06/20] best_ep= 15  RMSE=3.5783  NLL(z)=-0.2381  R²=0.9515\n",
            "[07/20] best_ep= 17  RMSE=4.7975  NLL(z)=0.4492  R²=0.9219\n",
            "[08/20] best_ep= 18  RMSE=5.0889  NLL(z)=0.1390  R²=0.9016\n",
            "[09/20] best_ep= 15  RMSE=4.1946  NLL(z)=-0.0887  R²=0.9396\n",
            "[10/20] best_ep= 15  RMSE=5.1534  NLL(z)=0.7313  R²=0.9083\n",
            "[11/20] best_ep= 18  RMSE=3.3168  NLL(z)=-0.0573  R²=0.9546\n",
            "[12/20] best_ep= 16  RMSE=4.3361  NLL(z)=0.2097  R²=0.9297\n",
            "[13/20] best_ep= 14  RMSE=4.4016  NLL(z)=0.1952  R²=0.9328\n",
            "[14/20] best_ep= 16  RMSE=3.5618  NLL(z)=0.1405  R²=0.9490\n",
            "[15/20] best_ep= 16  RMSE=4.8248  NLL(z)=0.3035  R²=0.8868\n",
            "[16/20] best_ep= 18  RMSE=5.1922  NLL(z)=0.3494  R²=0.8770\n",
            "[17/20] best_ep= 13  RMSE=3.6228  NLL(z)=-0.0118  R²=0.9489\n",
            "[18/20] best_ep= 18  RMSE=4.6753  NLL(z)=0.0750  R²=0.9148\n",
            "[19/20] best_ep= 15  RMSE=5.2829  NLL(z)=0.1386  R²=0.9061\n",
            "[20/20] best_ep= 17  RMSE=4.2535  NLL(z)=0.3735  R²=0.9246\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\n",
            "RMSE = 4.4397 ± 0.1380\n",
            "NLL  = 0.1973 ± 0.0535\n"
          ]
        }
      ],
      "source": [
        "# Concrete: same protocol as the Housing run (20 random 90/10 outer splits).\n",
        "# The summary header is driven by LABEL so it names the dataset actually evaluated.\n",
        "DATASET, LABEL, N_SPLITS = \"concrete\", \"Concrete\", 20\n",
        "\n",
        "set_seed(1)\n",
        "X, y = load_dataset(DATASET, verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "rng = np.random.RandomState(1)\n",
        "for i in range(N_SPLITS):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        D=2, K=8, HID=128, NC=3,\n",
        "        LR=1e-3, EPOCHS=400,\n",
        "        SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "        TOPK=2, SMOOTH_EPS=0.05,\n",
        "        seed=1+i\n",
        "    )\n",
        "    print(f\"[{i:02d}/{N_SPLITS}] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "\n",
        "rmses = np.array(rmses); nlls = np.array(nlls)\n",
        "# Standard error of the mean (sample std, ddof=1).\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "print(f\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on {LABEL} ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GgHfkdPK6B-t"
      },
      "source": [
        "## Energy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tFjoER9B4R_Q",
        "outputId": "4737f003-cf99-42bb-f18a-5b636a74a217"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Energy] X.shape=(768, 8) y.shape=(768,) | y∈[10.900,48.030]\n",
            "[01/20] best_ep= 41  RMSE=1.0822  NLL(z)=-0.9052  R²=0.9873\n",
            "[02/20] best_ep=291  RMSE=1.1954  NLL(z)=-0.8744  R²=0.9861\n",
            "[03/20] best_ep= 23  RMSE=1.0062  NLL(z)=-0.9035  R²=0.9886\n",
            "[04/20] best_ep= 40  RMSE=1.4348  NLL(z)=-0.8177  R²=0.9711\n",
            "[05/20] best_ep= 35  RMSE=1.4789  NLL(z)=-0.5157  R²=0.9768\n",
            "[06/20] best_ep= 31  RMSE=1.5282  NLL(z)=-0.7360  R²=0.9713\n",
            "[07/20] best_ep= 90  RMSE=1.2590  NLL(z)=-0.1996  R²=0.9853\n",
            "[08/20] best_ep= 25  RMSE=1.5734  NLL(z)=-0.7308  R²=0.9751\n",
            "[09/20] best_ep= 18  RMSE=0.8936  NLL(z)=-0.8905  R²=0.9904\n",
            "[10/20] best_ep= 23  RMSE=1.3440  NLL(z)=-0.9223  R²=0.9795\n",
            "[11/20] best_ep= 95  RMSE=1.0525  NLL(z)=-0.5799  R²=0.9881\n",
            "[12/20] best_ep= 31  RMSE=1.1252  NLL(z)=-1.0880  R²=0.9856\n",
            "[13/20] best_ep= 19  RMSE=1.4780  NLL(z)=-1.0114  R²=0.9776\n",
            "[14/20] best_ep=157  RMSE=1.1230  NLL(z)=-0.8295  R²=0.9855\n",
            "[15/20] best_ep=358  RMSE=1.0410  NLL(z)=-0.8093  R²=0.9880\n",
            "[16/20] best_ep=106  RMSE=1.1269  NLL(z)=-0.8397  R²=0.9867\n",
            "[17/20] best_ep= 39  RMSE=1.4285  NLL(z)=-0.3001  R²=0.9792\n",
            "[18/20] best_ep=101  RMSE=1.2265  NLL(z)=-0.5228  R²=0.9827\n",
            "[19/20] best_ep=156  RMSE=1.1442  NLL(z)=-0.9636  R²=0.9814\n",
            "[20/20] best_ep= 30  RMSE=1.1009  NLL(z)=-0.7154  R²=0.9868\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\n",
            "RMSE = 1.2321 ± 0.0440\n",
            "NLL  = -0.7578 ± 0.0514\n"
          ]
        }
      ],
      "source": [
        "set_seed(1)\n",
        "X, y = load_dataset(\"energy\", verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "rng = np.random.RandomState(1)\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        D=2, K=8, HID=128, NC=3,\n",
        "        LR=1e-3, EPOCHS=400,\n",
        "        SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "        TOPK=2, SMOOTH_EPS=0.05,\n",
        "        seed=1+i\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "\n",
        "rmses = np.array(rmses); nlls = np.array(nlls)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "print(\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y2yioONm7RxT"
      },
      "source": [
        "## Kin8nm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "94nabq_H7UUR",
        "outputId": "5ffbd8a4-f1dc-4f92-809b-3f9a9121a43f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Kin8nm] X.shape=(8192, 8) y.shape=(8192,) | y∈[0.040,1.459]\n",
            "[01/20] best_ep=  6  RMSE=0.1528  NLL(z)=0.9933  R²=0.6765\n",
            "[02/20] best_ep=  6  RMSE=0.1556  NLL(z)=1.0332  R²=0.6574\n",
            "[03/20] best_ep=  6  RMSE=0.1527  NLL(z)=1.1308  R²=0.6509\n",
            "[04/20] best_ep=  6  RMSE=0.1450  NLL(z)=0.9631  R²=0.6991\n",
            "[05/20] best_ep=  6  RMSE=0.1509  NLL(z)=1.0290  R²=0.6280\n",
            "[06/20] best_ep=  7  RMSE=0.1424  NLL(z)=1.0792  R²=0.7224\n",
            "[07/20] best_ep=  6  RMSE=0.1557  NLL(z)=1.0210  R²=0.6240\n",
            "[08/20] best_ep=  7  RMSE=0.1397  NLL(z)=0.9290  R²=0.7451\n",
            "[09/20] best_ep=  6  RMSE=0.1436  NLL(z)=0.9775  R²=0.6977\n",
            "[10/20] best_ep=  5  RMSE=0.1506  NLL(z)=1.0547  R²=0.6836\n",
            "[11/20] best_ep=  7  RMSE=0.1393  NLL(z)=0.9625  R²=0.7356\n",
            "[12/20] best_ep=  7  RMSE=0.1529  NLL(z)=1.0555  R²=0.6600\n",
            "[13/20] best_ep=  5  RMSE=0.1503  NLL(z)=0.9522  R²=0.6933\n",
            "[14/20] best_ep=  6  RMSE=0.1456  NLL(z)=0.9856  R²=0.6907\n",
            "[15/20] best_ep=  6  RMSE=0.1412  NLL(z)=0.8974  R²=0.7195\n",
            "[16/20] best_ep=  6  RMSE=0.1494  NLL(z)=0.9905  R²=0.6737\n",
            "[17/20] best_ep=  6  RMSE=0.1461  NLL(z)=1.0277  R²=0.6963\n",
            "[18/20] best_ep=  7  RMSE=0.1462  NLL(z)=0.9722  R²=0.6816\n",
            "[19/20] best_ep=  7  RMSE=0.1499  NLL(z)=1.0678  R²=0.6951\n",
            "[20/20] best_ep=  6  RMSE=0.1550  NLL(z)=0.9839  R²=0.6619\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\n",
            "RMSE = 0.1483 ± 0.0012\n",
            "NLL  = 1.0053 ± 0.0125\n"
          ]
        }
      ],
      "source": [
        "set_seed(1)\n",
        "X, y = load_dataset(\"kin8nm\", verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "rng = np.random.RandomState(1)\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        D=2, K=8, HID=128, NC=3,\n",
        "        LR=1e-3, EPOCHS=400,\n",
        "        SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "        TOPK=2, SMOOTH_EPS=0.05,\n",
        "        seed=1+i\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "\n",
        "rmses = np.array(rmses); nlls = np.array(nlls)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "print(\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BOIAEc-AFrCm"
      },
      "source": [
        "## Naval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "R2C26TAqFu6w",
        "outputId": "42814c5b-3548-4526-cf72-84c46f1e52bc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Naval(local)] X.shape=(11934, 17) y.shape=(11934,) | y∈[0.975000,1.000000]\n",
            "[01/20] best_ep= 87  RMSE=0.0006  NLL(z)=-1.1981  R²=0.9934\n",
            "[02/20] best_ep= 82  RMSE=0.0009  NLL(z)=-1.1024  R²=0.9854\n",
            "[03/20] best_ep= 99  RMSE=0.0007  NLL(z)=-0.9135  R²=0.9922\n",
            "[04/20] best_ep=212  RMSE=0.0006  NLL(z)=-1.2326  R²=0.9932\n",
            "[05/20] best_ep=144  RMSE=0.0007  NLL(z)=-0.8919  R²=0.9923\n",
            "[06/20] best_ep= 64  RMSE=0.0007  NLL(z)=-0.8797  R²=0.9925\n",
            "[07/20] best_ep=121  RMSE=0.0006  NLL(z)=-1.0320  R²=0.9934\n",
            "[08/20] best_ep=397  RMSE=0.0006  NLL(z)=-1.1559  R²=0.9946\n",
            "[09/20] best_ep=101  RMSE=0.0006  NLL(z)=-1.1355  R²=0.9930\n",
            "[10/20] best_ep=232  RMSE=0.0006  NLL(z)=-0.9679  R²=0.9928\n",
            "[11/20] best_ep=110  RMSE=0.0006  NLL(z)=-1.0523  R²=0.9935\n",
            "[12/20] best_ep=293  RMSE=0.0006  NLL(z)=-1.1137  R²=0.9936\n",
            "[13/20] best_ep=127  RMSE=0.0007  NLL(z)=-1.0909  R²=0.9921\n",
            "[14/20] best_ep=259  RMSE=0.0005  NLL(z)=-1.2004  R²=0.9948\n",
            "[15/20] best_ep=151  RMSE=0.0006  NLL(z)=-1.1489  R²=0.9939\n",
            "[16/20] best_ep=124  RMSE=0.0006  NLL(z)=-1.1279  R²=0.9933\n",
            "[17/20] best_ep=130  RMSE=0.0005  NLL(z)=-1.1922  R²=0.9943\n",
            "[18/20] best_ep=117  RMSE=0.0007  NLL(z)=-1.1506  R²=0.9904\n",
            "[19/20] best_ep=267  RMSE=0.0006  NLL(z)=-1.2105  R²=0.9943\n",
            "[20/20] best_ep=398  RMSE=0.0006  NLL(z)=-1.0848  R²=0.9931\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\n",
            "RMSE = 0.0006 ± 0.0000\n",
            "NLL  = -1.0941 ± 0.0240\n"
          ]
        }
      ],
      "source": [
        "set_seed(1)\n",
        "X, y = load_dataset(\"naval\", verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "rng = np.random.RandomState(1)\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        D=2, K=8, HID=128, NC=3,\n",
        "        LR=1e-3, EPOCHS=400,\n",
        "        SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "        TOPK=2, SMOOTH_EPS=0.05,\n",
        "        seed=1+i\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "\n",
        "rmses = np.array(rmses); nlls = np.array(nlls)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "print(\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "977xYDKgUfU-"
      },
      "source": [
        "## Power"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oP64ET9sUe5d",
        "outputId": "a8c3c593-ae05-4105-c121-6a2f1abd7f3d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Power(local)] X.shape=(9568, 4) y.shape=(9568,) | y∈[420.260,495.760]\n",
            "[01/20] best_ep= 25  RMSE=3.2342  NLL(z)=-0.1510  R²=0.9624\n",
            "[02/20] best_ep= 57  RMSE=3.5139  NLL(z)=-0.0653  R²=0.9584\n",
            "[03/20] best_ep= 20  RMSE=3.0139  NLL(z)=-0.2232  R²=0.9699\n",
            "[04/20] best_ep= 22  RMSE=3.0128  NLL(z)=-0.2590  R²=0.9668\n",
            "[05/20] best_ep= 19  RMSE=3.6344  NLL(z)=-0.1241  R²=0.9561\n",
            "[06/20] best_ep= 55  RMSE=3.5673  NLL(z)=-0.0333  R²=0.9565\n",
            "[07/20] best_ep= 25  RMSE=3.0617  NLL(z)=-0.1754  R²=0.9671\n",
            "[08/20] best_ep= 53  RMSE=3.0818  NLL(z)=-0.1687  R²=0.9679\n",
            "[09/20] best_ep= 19  RMSE=3.4341  NLL(z)=-0.1461  R²=0.9604\n",
            "[10/20] best_ep= 95  RMSE=3.3008  NLL(z)=0.0172  R²=0.9618\n",
            "[11/20] best_ep= 21  RMSE=3.0169  NLL(z)=-0.1340  R²=0.9681\n",
            "[12/20] best_ep= 22  RMSE=3.3524  NLL(z)=-0.0448  R²=0.9623\n",
            "[13/20] best_ep= 45  RMSE=2.9765  NLL(z)=-0.2535  R²=0.9703\n",
            "[14/20] best_ep= 48  RMSE=2.9554  NLL(z)=-0.1962  R²=0.9697\n",
            "[15/20] best_ep= 23  RMSE=3.3003  NLL(z)=-0.2281  R²=0.9610\n",
            "[16/20] best_ep= 18  RMSE=2.9096  NLL(z)=-0.2048  R²=0.9713\n",
            "[17/20] best_ep= 74  RMSE=3.0391  NLL(z)=-0.1614  R²=0.9683\n",
            "[18/20] best_ep= 18  RMSE=3.4603  NLL(z)=-0.1520  R²=0.9598\n",
            "[19/20] best_ep= 50  RMSE=3.3966  NLL(z)=-0.2332  R²=0.9612\n",
            "[20/20] best_ep= 19  RMSE=3.0447  NLL(z)=-0.2113  R²=0.9674\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\n",
            "RMSE = 3.2153 ± 0.0514\n",
            "NLL  = -0.1574 ± 0.0171\n"
          ]
        }
      ],
      "source": [
        "set_seed(1)\n",
        "X, y = load_dataset(\"power\", verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "rng = np.random.RandomState(1)\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        D=2, K=8, HID=128, NC=3,\n",
        "        LR=1e-3, EPOCHS=400,\n",
        "        SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "        TOPK=2, SMOOTH_EPS=0.05,\n",
        "        seed=1+i\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "\n",
        "rmses = np.array(rmses); nlls = np.array(nlls)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "print(\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3vxjMaXUarFS"
      },
      "source": [
        "## Protein"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "N0CLnV5Naqop",
        "outputId": "d809a457-55b5-4e41-a281-1f2a96a2c3e9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Protein(local)] X.shape=(45730, 9) y.shape=(45730,) | y∈[0.000,20.999]\n",
            "[Protein] downsampled: 45730 -> 10000\n",
            "[01/20] best_ep=  6  RMSE=4.2650  NLL(z)=1.0274  R²=0.5230\n",
            "[02/20] best_ep=  7  RMSE=4.4836  NLL(z)=1.1656  R²=0.4544\n",
            "[03/20] best_ep=  6  RMSE=4.3765  NLL(z)=1.1102  R²=0.5026\n",
            "[04/20] best_ep= 87  RMSE=4.3210  NLL(z)=0.8869  R²=0.5160\n",
            "[05/20] best_ep=  6  RMSE=4.4092  NLL(z)=1.0418  R²=0.4870\n",
            "[06/20] best_ep= 50  RMSE=4.4048  NLL(z)=0.9372  R²=0.4621\n",
            "[07/20] best_ep=  5  RMSE=4.5262  NLL(z)=1.1422  R²=0.4557\n",
            "[08/20] best_ep=  4  RMSE=4.4356  NLL(z)=1.1459  R²=0.4610\n",
            "[09/20] best_ep=  5  RMSE=4.2557  NLL(z)=1.0701  R²=0.4922\n",
            "[10/20] best_ep=  5  RMSE=4.3820  NLL(z)=1.0539  R²=0.5069\n",
            "[11/20] best_ep=  3  RMSE=4.3189  NLL(z)=1.0704  R²=0.4725\n",
            "[12/20] best_ep=  5  RMSE=4.4878  NLL(z)=1.1241  R²=0.4558\n",
            "[13/20] best_ep= 70  RMSE=4.4511  NLL(z)=1.0055  R²=0.4690\n",
            "[14/20] best_ep=  6  RMSE=4.3288  NLL(z)=1.1723  R²=0.5186\n",
            "[15/20] best_ep=129  RMSE=4.4782  NLL(z)=0.7762  R²=0.4731\n",
            "[16/20] best_ep=  5  RMSE=4.4936  NLL(z)=1.0984  R²=0.4757\n",
            "[17/20] best_ep=  5  RMSE=4.6575  NLL(z)=1.1721  R²=0.4300\n",
            "[18/20] best_ep=  7  RMSE=4.4965  NLL(z)=1.0885  R²=0.4830\n",
            "[19/20] best_ep= 71  RMSE=4.6680  NLL(z)=0.8513  R²=0.4312\n",
            "[20/20] best_ep=  6  RMSE=4.2553  NLL(z)=1.0109  R²=0.5125\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Protein (10k) ==\n",
            "RMSE = 4.4248 ± 0.0263\n",
            "NLL  = 1.0476 ± 0.0247\n"
          ]
        }
      ],
      "source": [
        "# === Run on Protein (10k downsample) — Router Ablation (with Anchor + CAL, no-leak) ===\n",
        "set_seed(1)\n",
        "\n",
        "\n",
        "X, y = load_dataset(\"protein\", verbose=True)\n",
        "\n",
        "MAX_N = 10_000\n",
        "rng_ds = np.random.RandomState(1)  \n",
        "n_full = X.shape[0]\n",
        "if n_full > MAX_N:\n",
        "    keep_idx = rng_ds.choice(n_full, size=MAX_N, replace=False)\n",
        "    X = X[keep_idx].copy()\n",
        "    y = y[keep_idx].copy()\n",
        "print(f\"[Protein] downsampled: {n_full} -> {X.shape[0]}\")\n",
        "\n",
        "n = X.shape[0]\n",
        "rng = np.random.RandomState(1)\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS    = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=1 + i\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.array(rmses); nlls = np.array(nlls)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on Protein (10k) ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nC5Dx5ZdFtrX"
      },
      "source": [
        "## Wine"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "j39pkIZql0vN",
        "outputId": "73bf602d-5e4b-4e79-fb26-4afff11328dc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[WineRed] X.shape=(1599, 11) y.shape=(1599,) | y∈[3.000,8.000]\n",
            "[01/20] best_ep=  3  RMSE=0.6191  NLL(z)=1.1486  R²=0.4779\n",
            "[02/20] best_ep=  4  RMSE=0.5933  NLL(z)=1.1272  R²=0.4309\n",
            "[03/20] best_ep=  3  RMSE=0.7340  NLL(z)=1.4439  R²=0.3934\n",
            "[04/20] best_ep=  2  RMSE=0.5916  NLL(z)=1.1027  R²=0.4483\n",
            "[05/20] best_ep=  4  RMSE=0.5465  NLL(z)=1.0120  R²=0.4642\n",
            "[06/20] best_ep=  4  RMSE=0.6048  NLL(z)=1.1110  R²=0.4296\n",
            "[07/20] best_ep=  2  RMSE=0.6740  NLL(z)=1.2361  R²=0.4062\n",
            "[08/20] best_ep=  3  RMSE=0.6553  NLL(z)=1.2097  R²=0.4387\n",
            "[09/20] best_ep=  2  RMSE=0.5934  NLL(z)=1.0664  R²=0.4281\n",
            "[10/20] best_ep=  3  RMSE=0.6312  NLL(z)=1.2219  R²=0.4748\n",
            "[11/20] best_ep=  4  RMSE=0.6203  NLL(z)=1.2297  R²=0.4588\n",
            "[12/20] best_ep= 20  RMSE=0.6549  NLL(z)=1.1588  R²=0.2508\n",
            "[13/20] best_ep=  3  RMSE=0.6458  NLL(z)=1.2793  R²=0.3456\n",
            "[14/20] best_ep=  4  RMSE=0.6289  NLL(z)=1.1836  R²=0.2503\n",
            "[15/20] best_ep=  5  RMSE=0.5853  NLL(z)=1.0740  R²=0.5255\n",
            "[16/20] best_ep= 11  RMSE=0.5634  NLL(z)=1.0104  R²=0.4300\n",
            "[17/20] best_ep=  5  RMSE=0.6042  NLL(z)=1.1254  R²=0.4866\n",
            "[18/20] best_ep=  6  RMSE=0.5796  NLL(z)=1.1948  R²=0.5204\n",
            "[19/20] best_ep=  4  RMSE=0.5903  NLL(z)=1.2011  R²=0.4529\n",
            "[20/20] best_ep=  3  RMSE=0.5945  NLL(z)=1.1445  R²=0.4848\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\n",
            "RMSE = 0.6155 ± 0.0095\n",
            "NLL  = 1.1641 ± 0.0220\n"
          ]
        }
      ],
      "source": [
        "set_seed(1)\n",
        "X, y = load_dataset(\"wine\", verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "rng = np.random.RandomState(1)\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        D=2, K=8, HID=128, NC=3,\n",
        "        LR=1e-3, EPOCHS=400,\n",
        "        SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "        TOPK=2, SMOOTH_EPS=0.05,\n",
        "        seed=1+i\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "\n",
        "rmses = np.array(rmses); nlls = np.array(nlls)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "print(\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6-TbIb0AnpGR"
      },
      "source": [
        "## Yacht"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XeDyJS0qnsDi",
        "outputId": "4a554673-5b73-42be-f310-b09ef59a238e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Yacht] X.shape=(308, 6) y.shape=(308,) | y∈[0.010,62.420]\n",
            "[01/20] best_ep=283  RMSE=0.2997  NLL(z)=-1.9975  R²=0.9996\n",
            "[02/20] best_ep=137  RMSE=0.5804  NLL(z)=-1.7256  R²=0.9983\n",
            "[03/20] best_ep=289  RMSE=0.4620  NLL(z)=-1.9131  R²=0.9983\n",
            "[04/20] best_ep=170  RMSE=0.7143  NLL(z)=-1.7618  R²=0.9984\n",
            "[05/20] best_ep=362  RMSE=0.5734  NLL(z)=-1.8615  R²=0.9989\n",
            "[06/20] best_ep=249  RMSE=0.5007  NLL(z)=-1.8520  R²=0.9981\n",
            "[07/20] best_ep=370  RMSE=0.2688  NLL(z)=-2.0164  R²=0.9982\n",
            "[08/20] best_ep=345  RMSE=0.7860  NLL(z)=-1.5648  R²=0.9971\n",
            "[09/20] best_ep=137  RMSE=0.5190  NLL(z)=-1.8564  R²=0.9984\n",
            "[10/20] best_ep=390  RMSE=0.4961  NLL(z)=-1.7786  R²=0.9972\n",
            "[11/20] best_ep=391  RMSE=0.3863  NLL(z)=-1.9419  R²=0.9995\n",
            "[12/20] best_ep=246  RMSE=0.7766  NLL(z)=-1.6490  R²=0.9960\n",
            "[13/20] best_ep=122  RMSE=0.9339  NLL(z)=-1.5199  R²=0.9967\n",
            "[14/20] best_ep=248  RMSE=0.8091  NLL(z)=-1.5349  R²=0.9975\n",
            "[15/20] best_ep=289  RMSE=0.6055  NLL(z)=-1.7349  R²=0.9985\n",
            "[16/20] best_ep= 47  RMSE=0.7278  NLL(z)=-1.5564  R²=0.9973\n",
            "[17/20] best_ep=367  RMSE=0.5095  NLL(z)=-1.9212  R²=0.9985\n",
            "[18/20] best_ep=117  RMSE=0.7824  NLL(z)=-1.6695  R²=0.9970\n",
            "[19/20] best_ep= 88  RMSE=0.7997  NLL(z)=-1.6871  R²=0.9974\n",
            "[20/20] best_ep=255  RMSE=0.8211  NLL(z)=-1.6827  R²=0.9981\n",
            "\n",
            "== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\n",
            "RMSE = 0.6176 ± 0.0419\n",
            "NLL  = -1.7613 ± 0.0346\n"
          ]
        }
      ],
      "source": [
        "set_seed(1)\n",
        "X, y = load_dataset(\"yacht\", verbose=True)\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "rng = np.random.RandomState(1)\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, r2 = train_one_split_router_off(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        D=2, K=8, HID=128, NC=3,\n",
        "        LR=1e-3, EPOCHS=400,\n",
        "        SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "        TOPK=2, SMOOTH_EPS=0.05,\n",
        "        seed=1+i\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] best_ep={best_ep:3d}  RMSE={rmse:.4f}  NLL(z)={nll:.4f}  R²={r2:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "\n",
        "rmses = np.array(rmses); nlls = np.array(nlls)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "print(\"\\n== MoE (Router Ablation, with Anchor + CAL, no-leak) on Housing ==\")\n",
        "print(f\"RMSE = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PHuPXhPW6BqF"
      },
      "source": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Id2izJGi4RRR"
      },
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "gpuType": "V6E1",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
