{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RahY6KzUJ-sO"
      },
      "source": [
        "# **Ablating Anchor**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xA9oaTjMNUQi"
      },
      "source": [
        "## Model Definition"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "JbFcmESdKPjM"
      },
      "outputs": [],
      "source": [
        "\n",
        "import os, math, gc\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from typing import Tuple, Dict, List, Optional\n",
        "\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.metrics import mean_squared_error, r2_score\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "\n",
        "# ---------- Globals ----------\n",
        "LOG2PI  = math.log(2*math.pi)\n",
        "DEVICE  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "torch.set_default_dtype(torch.float32)\n",
        "\n",
        "# ---------- Utilities ----------\n",
        "def set_seed(seed: int = 1) -> None:\n",
        "    np.random.seed(seed); torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "        torch.backends.cudnn.deterministic = True\n",
        "        torch.backends.cudnn.benchmark = False\n",
        "\n",
        "def rmse_score(y_true, y_pred) -> float:\n",
        "    \"\"\"Root-mean-squared error between targets and predictions, as a Python float.\"\"\"\n",
        "    mse = mean_squared_error(y_true, y_pred)\n",
        "    return float(np.sqrt(mse))\n",
        "\n",
        "def zscore_fit(y_train: np.ndarray) -> Tuple[float, float]:\n",
        "    my = float(np.mean(y_train)); sy = float(np.std(y_train) + 1e-8)\n",
        "    return my, sy\n",
        "\n",
        "def to_tensor(x: np.ndarray, yz: Optional[np.ndarray] = None,\n",
        "              mx: Optional[np.ndarray] = None, sx: Optional[np.ndarray] = None):\n",
        "    \"\"\"Standardise ``x`` as ``(x - mx) / sx`` and move it to DEVICE as float32.\n",
        "\n",
        "    ``mx`` / ``sx`` default to 0.0 / 1.0 (no shift / no scaling). When ``yz`` is\n",
        "    given it is converted to a float32 tensor as-is and ``(Xt, Yt)`` is\n",
        "    returned; otherwise only ``Xt`` is returned.\n",
        "    \"\"\"\n",
        "    shift = 0.0 if mx is None else mx\n",
        "    scale = 1.0 if sx is None else sx\n",
        "    x_std = ((x - shift) / scale).astype(np.float32)\n",
        "    Xt = torch.tensor(x_std, device=DEVICE)\n",
        "    if yz is not None:\n",
        "        Yt = torch.tensor(yz.astype(np.float32), device=DEVICE)\n",
        "        return Xt, Yt\n",
        "    return Xt\n",
        "\n",
        "def make_random_splits(n: int, train_frac: float = 0.9, n_splits: int = 20,\n",
        "                       seed: int = 1) -> List[Tuple[np.ndarray, np.ndarray]]:\n",
        "    rng = np.random.RandomState(seed)\n",
        "    splits = []\n",
        "    for _ in range(n_splits):\n",
        "        perm = rng.permutation(n)\n",
        "        tr = perm[: round(n*train_frac)]\n",
        "        te = perm[round(n*train_frac):]\n",
        "        splits.append((tr, te))\n",
        "    return splits\n",
        "\n",
        "def _clean_numeric(df: pd.DataFrame) -> pd.DataFrame:\n",
        "    df = df.apply(pd.to_numeric, errors=\"coerce\")\n",
        "    df = df.replace([np.inf, -np.inf], np.nan).dropna(axis=0)\n",
        "    return df\n",
        "\n",
        "# ---------- Dataset loaders ----------\n",
        "def load_dataset(name: str, verbose: bool = False, **kwargs) -> Tuple[np.ndarray, np.ndarray]:\n",
        "    key = name.lower().strip()\n",
        "\n",
        "    if key == \"housing\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data\",\n",
        "            header=None, sep=r\"\\s+\", engine=\"python\", comment=\"#\", skip_blank_lines=True\n",
        "        )\n",
        "        if df.shape[1] == 1:  # some mirrors pack all columns into one\n",
        "            df = df[0].astype(str).str.strip().str.split(r\"\\s+\", expand=True)\n",
        "        df = _clean_numeric(df); assert df.shape[1] >= 14\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "\n",
        "    elif key == \"concrete\":\n",
        "        df = pd.read_excel(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/concrete/compressive/Concrete_Data.xls\"\n",
        "        )\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "\n",
        "    elif key == \"energy\":\n",
        "        df = pd.read_excel(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/00242/ENB2012_data.xlsx\"\n",
        "        ).iloc[:, :-1]\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64)  # features: all but last\n",
        "        y = df.iloc[:, -1].to_numpy(np.float64)   # target: last (Y2)\n",
        "\n",
        "    elif key == \"wine\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\",\n",
        "            delimiter=\";\"\n",
        "        )\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "\n",
        "    elif key == \"yacht\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/00243/yacht_hydrodynamics.data\",\n",
        "            header=None, sep=r\"\\s+\", engine=\"python\", comment=\"#\", skip_blank_lines=True\n",
        "        )\n",
        "        if df.shape[1] == 1:\n",
        "            df = df[0].astype(str).str.strip().str.split(r\"\\s+\", expand=True)\n",
        "        df = _clean_numeric(df); assert df.shape[1] >= 7\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "\n",
        "    elif key == \"kin8nm\":\n",
        "        local = kwargs.get(\"path\", \"kin8nm.csv\")\n",
        "        if not os.path.exists(local): raise FileNotFoundError(f\"[kin8nm] missing file: {local}\")\n",
        "        df = pd.read_csv(local); df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "\n",
        "    elif key == \"protein\":\n",
        "        local = kwargs.get(\"path\", \"protein.csv\")\n",
        "        if not os.path.exists(local): raise FileNotFoundError(f\"[protein] missing file: {local}\")\n",
        "        df = pd.read_csv(local, sep=None, engine=\"python\")\n",
        "        df = _clean_numeric(df)\n",
        "        y = df.iloc[:, 0].to_numpy(np.float64)\n",
        "        X = df.iloc[:, 1:].to_numpy(np.float64)\n",
        "\n",
        "    elif key == \"naval\":\n",
        "        local = kwargs.get(\"path\", \"naval.txt\")\n",
        "        if not os.path.exists(local): raise FileNotFoundError(f\"[naval] missing file: {local}\")\n",
        "        df = pd.read_csv(local, sep=r\"\\s+\", header=None, engine=\"python\")\n",
        "        df = df.iloc[:, :-1]  # last column often duplicate/unused\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64)\n",
        "        y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "\n",
        "    elif key == \"power\":\n",
        "        local = kwargs.get(\"path\", \"power.xlsx\")\n",
        "        if not os.path.exists(local): raise FileNotFoundError(f\"[power] missing file: {local}\")\n",
        "        df = pd.read_excel(local)\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64)\n",
        "        y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "\n",
        "    elif key == \"msd\":\n",
        "        local = kwargs.get(\"path\", \"YearPredictionMSD.txt\")\n",
        "        if not os.path.exists(local): raise FileNotFoundError(f\"[msd] missing file: {local}\")\n",
        "        df = pd.read_csv(local, header=None)\n",
        "        df = df.iloc[:, ::-1]  # make target last\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64)\n",
        "        y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "\n",
        "    else:\n",
        "        raise ValueError(\"Unknown dataset name.\")\n",
        "\n",
        "    if verbose:\n",
        "        print(f\"[{name}] X.shape={X.shape} y.shape={y.shape} | y∈[{np.min(y):.3f},{np.max(y):.3f}]\")\n",
        "    return X, y\n",
        "\n",
        "# ---------- MoE (No-Anchor) ----------\n",
        "def _topk_mask(w: torch.Tensor, k: int = 2) -> torch.Tensor:\n",
        "    _, topi = torch.topk(w, k, dim=-1)\n",
        "    mask = torch.zeros_like(w).scatter_(-1, topi, 1.0)\n",
        "    w2 = w * mask\n",
        "    return w2 / (w2.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "\n",
        "def _topk_mask_smooth(w: torch.Tensor, k: int = 2, eps: float = 0.05) -> torch.Tensor:\n",
        "    _, topi = torch.topk(w, k, dim=-1)\n",
        "    mask = torch.zeros_like(w).scatter_(-1, topi, 1.0)\n",
        "    w_top = w * mask\n",
        "    w_top = w_top / (w_top.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "    return (1.0 - eps) * w_top + (eps / k) * mask\n",
        "\n",
        "class Projection(nn.Module):\n",
        "    def __init__(self, d: int, D: int):\n",
        "        super().__init__()\n",
        "        self.w = nn.Linear(d, D, bias=True)\n",
        "        nn.init.xavier_uniform_(self.w.weight); nn.init.zeros_(self.w.bias)\n",
        "    def forward(self, x): return self.w(x)\n",
        "\n",
        "class Window(nn.Module):\n",
        "    def __init__(self, K: int, D: int, min_log_s: float = -2.5, max_log_s: float = 1.0):\n",
        "        super().__init__()\n",
        "        self.c         = nn.Parameter(torch.randn(K, D))\n",
        "        self.log_s     = nn.Parameter(torch.zeros(K, D))\n",
        "        self.min_log_s = min_log_s; self.max_log_s = max_log_s\n",
        "    def forward(self, z):\n",
        "        log_s = torch.clamp(self.log_s, min=self.min_log_s, max=self.max_log_s)\n",
        "        diff2 = ((z[:, None] - self.c)**2) / (2 * torch.exp(log_s)**2)\n",
        "        return torch.exp(-diff2.sum(dim=-1)) + 1e-12  # [B,K]\n",
        "\n",
        "class Router(nn.Module):\n",
        "    def __init__(self, D: int, K: int):\n",
        "        super().__init__()\n",
        "        self.q   = nn.Linear(D, 64)\n",
        "        self.k   = nn.Parameter(torch.randn(K, 64))\n",
        "        self.tau = 3.0\n",
        "    def forward(self, z):\n",
        "        logits = (self.q(z) @ self.k.T) / math.sqrt(64)\n",
        "        return F.softmax(logits / self.tau, dim=-1)\n",
        "\n",
        "class ExpertMDN(nn.Module):\n",
        "    def __init__(self, d: int, h: int, nc: int, sigma_min: float = 5e-2, sigma_max: float = 1.0):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(d, h), nn.ReLU(),\n",
        "            nn.Linear(h, h), nn.ReLU()\n",
        "        )\n",
        "        self.logits = nn.Linear(h, nc)\n",
        "        self.means  = nn.Linear(h, nc)\n",
        "        self.log_sc = nn.Linear(h, nc)\n",
        "        self.sigma_min = float(sigma_min)\n",
        "        self.sigma_max = float(sigma_max)\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, nn.Linear):\n",
        "                nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)\n",
        "    def forward(self, x):\n",
        "        h  = self.net(x)\n",
        "        pi = F.softmax(self.logits(h), dim=-1)\n",
        "        mu = self.means(h)\n",
        "        sg = torch.exp(self.log_sc(h)).clamp(self.sigma_min, self.sigma_max)\n",
        "        return pi, mu, sg\n",
        "\n",
        "class BLRMoE(nn.Module):\n",
        "    \"\"\"Mixture-of-experts density model for a scalar target, without an anchor term.\n",
        "\n",
        "    Inputs are projected to ``D`` dimensions; gate weights are the product of\n",
        "    the Window and Router outputs, renormalised and restricted to the top-k\n",
        "    experts per row. Each of the ``K`` experts is an ExpertMDN producing an\n",
        "    ``nc``-component Gaussian mixture.\n",
        "\n",
        "    Args:\n",
        "        d, D, K, hid, nc: input width, projection width, number of experts,\n",
        "            expert hidden width, mixture components per expert.\n",
        "        w_ent_warm, w_ent_cool: coefficient on ``sum(p * log p)`` of the router\n",
        "            probabilities during / after warm-up. With the defaults the\n",
        "            warm-up term favours higher router entropy and the later term\n",
        "            favours lower entropy.\n",
        "        l2_win: coefficient on ``mean(log_s ** 2)`` of the Window.\n",
        "        lb_coef: coefficient on the squared deviation of the batch-mean router\n",
        "            probabilities from uniform.\n",
        "        sigma_min, sigma_max: clamp range of the expert scales.\n",
        "        topk: experts kept per row.\n",
        "        smooth_eps: smoothing used by the training-time top-k gate.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, d: int, D: int, K: int, hid: int, nc: int,\n",
        "                 w_ent_warm: float = +1e-3, w_ent_cool: float = -2e-4,\n",
        "                 l2_win: float = 1e-4, lb_coef: float = 1e-3,\n",
        "                 sigma_min: float = 5e-2, sigma_max: float = 1.0,\n",
        "                 topk: int = 2, smooth_eps: float = 0.05):\n",
        "        super().__init__()\n",
        "        self.proj   = Projection(d, D)\n",
        "        self.win    = Window(K, D)\n",
        "        self.router = Router(D, K)\n",
        "        self.exps   = nn.ModuleList([ExpertMDN(d, hid, nc, sigma_min, sigma_max) for _ in range(K)])\n",
        "        self.w_ent_warm = w_ent_warm\n",
        "        self.w_ent_cool = w_ent_cool\n",
        "        self.l2_win     = l2_win\n",
        "        self.lb_coef    = lb_coef\n",
        "        self.topk       = topk\n",
        "        self.smooth_eps = smooth_eps\n",
        "        self.K = K; self.nc = nc\n",
        "\n",
        "    def _mixture_params(self, X: torch.Tensor, train: bool = True):\n",
        "        \"\"\"Compute gate weights and per-expert mixture parameters for a batch.\n",
        "\n",
        "        Returns ``(w, Pi, Mu, Sg, z)``: ``w`` is ``[B, K]`` and zero outside each\n",
        "        row's top-k; ``Pi``/``Mu``/``Sg`` are ``[B, K, C]``; ``z`` is the\n",
        "        projected input. Entries of ``Pi``/``Mu``/``Sg`` for experts a row did\n",
        "        not select keep placeholder values (uniform weights, zero mean, unit scale).\n",
        "        ``train=True`` uses the smoothed top-k gate, ``train=False`` the hard one.\n",
        "        \"\"\"\n",
        "        z = self.proj(X)\n",
        "        w = self.win(z) * self.router(z)\n",
        "        w = w / (w.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "        w = _topk_mask_smooth(w, k=self.topk, eps=self.smooth_eps) if train else _topk_mask(w, k=self.topk)\n",
        "\n",
        "        B, K, C = X.size(0), self.K, self.nc\n",
        "        # Allocate placeholders in the gate's dtype: with the global default dtype,\n",
        "        # a non-float32 model would be silently downcast by the slice assignments below.\n",
        "        Pi = torch.full((B, K, C), 1.0/C, device=X.device, dtype=w.dtype)\n",
        "        Mu = torch.zeros(B, K, C, device=X.device, dtype=w.dtype)\n",
        "        Sg = torch.ones(B, K, C, device=X.device, dtype=w.dtype)\n",
        "\n",
        "        # Only experts selected by at least one row are evaluated.\n",
        "        _, topi = torch.topk(w, self.topk, dim=-1)\n",
        "        uniq = torch.unique(topi)\n",
        "        for j in uniq.tolist():\n",
        "            pi_j, mu_j, sg_j = self.exps[j](X)\n",
        "            # mask[b] == 1 where row b routed to expert j; other rows keep the placeholder.\n",
        "            mask = (topi == j).any(dim=1).float().unsqueeze(-1)\n",
        "            Pi[:, j, :] = pi_j * mask + Pi[:, j, :]*(1 - mask)\n",
        "            Mu[:, j, :] = mu_j * mask + Mu[:, j, :]*(1 - mask)\n",
        "            Sg[:, j, :] = sg_j * mask + Sg[:, j, :]*(1 - mask)\n",
        "        return w, Pi, Mu, Sg, z\n",
        "\n",
        "    # keep signature; mu_anchor_z is ignored\n",
        "    def nll(self, X: torch.Tensor, y_z: torch.Tensor, mu_anchor_z=None,\n",
        "            epoch: int = 1, warmup_epochs: int = 150) -> torch.Tensor:\n",
        "        \"\"\"Batch-mean negative log-likelihood of ``y_z`` under the mixture.\n",
        "\n",
        "        In training mode (``self.training``) the smoothed gate is used and three\n",
        "        regularisers are added: the router term weighted by ``w_ent_warm`` while\n",
        "        ``epoch <= warmup_epochs`` and by ``w_ent_cool`` afterwards, the Window\n",
        "        log-width penalty, and the load-balance penalty. In eval mode the plain\n",
        "        NLL with the hard top-k gate is returned. ``mu_anchor_z`` is unused.\n",
        "        \"\"\"\n",
        "        train_flag = self.training\n",
        "        w, Pi, Mu, Sg, z = self._mixture_params(X, train=train_flag)\n",
        "\n",
        "        Mu_eff = Mu  # no-anchor\n",
        "        yv   = y_z[:, None, None]\n",
        "        logp = -0.5 * ((yv - Mu_eff)/Sg)**2 - torch.log(Sg) - 0.5*LOG2PI\n",
        "\n",
        "        w3   = w[:, :, None]\n",
        "        # Experts with zero gate weight get a large negative log-weight so they drop out of the logsumexp.\n",
        "        logw = torch.where(w3 > 0, torch.log(w3 + 1e-12), torch.full_like(w3, -1e9))\n",
        "        logpi= torch.log(Pi + 1e-12)\n",
        "        logmix = torch.logsumexp(logw + logpi + logp, dim=(1,2))\n",
        "        nll = -logmix.mean()\n",
        "\n",
        "        if train_flag:\n",
        "            w_ent = self.w_ent_warm if epoch <= warmup_epochs else self.w_ent_cool\n",
        "            p = self.router(z)\n",
        "            # NOTE: `ent` is sum(p * log p), i.e. the NEGATIVE entropy of the router.\n",
        "            ent = (p * torch.log(p + 1e-12)).sum(dim=1).mean()\n",
        "            l2w = (self.win.log_s**2).mean()\n",
        "            rho = p.mean(dim=0)\n",
        "            lb_loss = ((rho - 1.0/p.size(1))**2).sum()\n",
        "            nll = nll + w_ent*ent + self.l2_win*l2w + self.lb_coef*lb_loss\n",
        "        return nll\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def predict_mean_var(self, X: torch.Tensor, mu_anchor_z=None):\n",
        "        \"\"\"Predictive mean and variance of the full mixture (hard top-k gate, no grad).\n",
        "\n",
        "        The variance is the mixture's second moment minus the squared mean,\n",
        "        floored at 1e-9. ``mu_anchor_z`` is unused.\n",
        "        \"\"\"\n",
        "        w, Pi, Mu, Sg, _ = self._mixture_params(X, train=False)\n",
        "        Mu_eff = Mu\n",
        "        mu_e  = (Pi * Mu_eff).sum(dim=2)\n",
        "        m2_e  = (Pi * (Sg**2 + Mu_eff**2)).sum(dim=2)\n",
        "        mu_z  = (w * mu_e).sum(dim=1)\n",
        "        second= (w * m2_e).sum(dim=1)\n",
        "        var_z = torch.clamp(second - mu_z**2, min=1e-9)\n",
        "        return mu_z, var_z\n",
        "\n",
        "# ---------- Training/Eval (Leak-free, No-Anchor) ----------\n",
        "def train_one_split_no_anchor(\n",
        "    X_all: np.ndarray, y_all: np.ndarray,    # 90% train-side (outer split)\n",
        "    X_te:  np.ndarray, y_te:  np.ndarray,    # 10% test-side (outer split, only for final eval)\n",
        "    standardize_x: bool = True,\n",
        "    D: int = 4, K: int = 8, HID: int = 128, NC: int = 3,\n",
        "    LR: float = 1e-3, EPOCHS: int = 400,\n",
        "    SIGMA_MIN: float = 5e-2, SIGMA_MAX: float = 1.0,\n",
        "    TOPK: int = 2, SMOOTH_EPS: float = 0.05,\n",
        "    seed: int = 1\n",
        ") -> Tuple[float, float, int, float]:\n",
        "    \"\"\"Train and evaluate the no-anchor MoE on one outer split.\n",
        "\n",
        "    Protocol:\n",
        "      1. Split the train side into TV / CAL (12.5% CAL), then TV into TR / VA (20% VA).\n",
        "      2. Phase 1: full-batch training on TR for ``EPOCHS`` epochs; remember the\n",
        "         epoch and weights with the lowest VA NLL.\n",
        "      3. Phase 2: reload those weights and train ``best_ep`` more epochs on TV\n",
        "         (targets z-scored with TV statistics, features with TR statistics).\n",
        "      4. Fit a linear map ``y ~ a * mu + b`` on CAL and apply it to the test means.\n",
        "\n",
        "    ``seed`` only controls the two ``train_test_split`` calls; model\n",
        "    initialisation uses whatever state the global torch RNG is in.\n",
        "\n",
        "    Returns:\n",
        "        ``(rmse, test_nll, best_ep, r2)``: RMSE and R^2 of the calibrated test\n",
        "        means in original units, the test NLL in z-space (TV statistics, no\n",
        "        calibration), and the selected phase-1 epoch.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: No phase-1 epoch produced a usable validation NLL\n",
        "            (e.g. the loss was NaN throughout, or ``EPOCHS < 1``).\n",
        "    \"\"\"\n",
        "    WARMUP_EPOCHS = 150   # epochs during which BLRMoE.nll uses w_ent_warm\n",
        "    WEIGHT_DECAY  = 3e-4\n",
        "    GRAD_CLIP     = 2.0   # max gradient norm\n",
        "    TAU_DECAY, TAU_FLOOR = 0.995, 1.0\n",
        "\n",
        "    # TV/CAL split (CAL only for linear calibration)\n",
        "    X_tv, X_cal, y_tv, y_cal = train_test_split(X_all, y_all, test_size=0.125, random_state=seed)\n",
        "    # TR/VA split for early stopping\n",
        "    X_tr, X_va, y_tr, y_va   = train_test_split(X_tv,  y_tv,  test_size=0.2,   random_state=seed)\n",
        "\n",
        "    # Feature standardization (fit on TR only)\n",
        "    if standardize_x:\n",
        "        mx_tr = X_tr.mean(0, keepdims=True); sx_tr = X_tr.std(0, keepdims=True) + 1e-8\n",
        "    else:\n",
        "        mx_tr = np.zeros((1, X_tr.shape[1])); sx_tr = np.ones((1, X_tr.shape[1]))\n",
        "\n",
        "    # y z-score (Phase-1 on TR stats)\n",
        "    my_tr, sy_tr = zscore_fit(y_tr)\n",
        "    y_tr_z = (y_tr - my_tr) / sy_tr\n",
        "    y_va_z = (y_va - my_tr) / sy_tr\n",
        "\n",
        "    # Tensors\n",
        "    Xtr_t, ytr_t = to_tensor(X_tr, y_tr_z, mx_tr, sx_tr)\n",
        "    Xva_t, yva_t = to_tensor(X_va, y_va_z, mx_tr, sx_tr)\n",
        "\n",
        "    # Model\n",
        "    model = BLRMoE(\n",
        "        d=X_tr.shape[1], D=D, K=K, hid=HID, nc=NC,\n",
        "        sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX,\n",
        "        topk=TOPK, smooth_eps=SMOOTH_EPS\n",
        "    ).to(DEVICE)\n",
        "\n",
        "    def _train_step(optimizer, Xt, yt, ep: int) -> None:\n",
        "        \"\"\"One full-batch update with gradient clipping, then decay the router temperature.\"\"\"\n",
        "        optimizer.zero_grad()\n",
        "        loss = model.nll(Xt, yt, mu_anchor_z=None, epoch=ep, warmup_epochs=WARMUP_EPOCHS)\n",
        "        loss.backward()\n",
        "        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
        "        optimizer.step()\n",
        "        model.router.tau = max(model.router.tau * TAU_DECAY, TAU_FLOOR)\n",
        "\n",
        "    # Phase 1: early stopping on VA-NLL\n",
        "    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
        "    best_ep, best_vnll, best_state = 0, +1e9, None\n",
        "    for ep in range(1, EPOCHS+1):\n",
        "        model.train()\n",
        "        _train_step(opt, Xtr_t, ytr_t, ep)\n",
        "\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            vnll = float(model.nll(Xva_t, yva_t, mu_anchor_z=None, epoch=ep, warmup_epochs=WARMUP_EPOCHS).cpu().item())\n",
        "        if vnll < best_vnll:\n",
        "            best_vnll = vnll; best_ep = ep\n",
        "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
        "\n",
        "    if best_state is None:\n",
        "        # Without this guard load_state_dict(None) fails with an unrelated-looking error.\n",
        "        raise RuntimeError(\n",
        "            f\"No checkpoint selected: validation NLL never dropped below {best_vnll:g} \"\n",
        "            f\"in {EPOCHS} epoch(s) (NaN/diverged loss, or EPOCHS < 1).\"\n",
        "        )\n",
        "\n",
        "    # Phase-2 retrain on TV for best_ep (y z-score now uses TV stats; features still TR stats)\n",
        "    my_tv, sy_tv = zscore_fit(y_tv)\n",
        "    X_tv_t, y_tv_t = to_tensor(X_tv, (y_tv - my_tv)/sy_tv, mx_tr, sx_tr)\n",
        "\n",
        "    # NOTE(review): router.tau is a plain attribute, so it is NOT restored by\n",
        "    # load_state_dict. Phase 2 therefore starts from the temperature left at the\n",
        "    # end of phase 1, not the one in effect at best_ep. Confirm this is intended.\n",
        "    model.load_state_dict(best_state)\n",
        "    model.train()\n",
        "    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
        "    for ep in range(1, best_ep+1):\n",
        "        _train_step(opt, X_tv_t, y_tv_t, ep)\n",
        "\n",
        "    # Evaluate: CAL-only linear calibration; TEST-only metrics\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        Xcal_t = to_tensor(X_cal, None, mx_tr, sx_tr)\n",
        "        Xte_t  = to_tensor(X_te,  None, mx_tr, sx_tr)\n",
        "\n",
        "        # NLL on TEST (using TV stats)\n",
        "        y_te_z = (y_te - my_tv)/sy_tv\n",
        "        yte_t  = torch.tensor(y_te_z.astype(np.float32), device=DEVICE)\n",
        "        test_nll = float(model.nll(Xte_t, yte_t, mu_anchor_z=None).cpu().item())\n",
        "\n",
        "        # Means back to original units using TV stats\n",
        "        mu_z_cal, _ = model.predict_mean_var(Xcal_t, mu_anchor_z=None)\n",
        "        mu_cal_orig = mu_z_cal.cpu().numpy().astype(np.float64) * sy_tv + my_tv\n",
        "\n",
        "        mu_z_te,  _ = model.predict_mean_var(Xte_t,  mu_anchor_z=None)\n",
        "        mu_te_orig  = mu_z_te.cpu().numpy().astype(np.float64)  * sy_tv + my_tv\n",
        "\n",
        "    # Linear post-calibration on CAL: least-squares fit of y_cal ~ a * mu_cal + b\n",
        "    A = np.vstack([mu_cal_orig, np.ones_like(mu_cal_orig)]).T\n",
        "    ab, *_ = np.linalg.lstsq(A, y_cal.astype(np.float64), rcond=None)\n",
        "    a, b = float(ab[0]), float(ab[1])\n",
        "\n",
        "    mu_te_cal = a * mu_te_orig + b\n",
        "    rmse = rmse_score(y_te.astype(np.float64), mu_te_cal)\n",
        "    r2   = float(r2_score(y_te.astype(np.float64), mu_te_cal))\n",
        "    return rmse, test_nll, best_ep, r2\n",
        "\n",
        "def run_no_anchor_experiment(\n",
        "    X: np.ndarray, y: np.ndarray,\n",
        "    n_splits: int = 20, outer_train_frac: float = 0.9, seed: int = 1,\n",
        "    standardize_x: bool = True,\n",
        "    D: int = 4, K: int = 8, HID: int = 128, NC: int = 3,\n",
        "    LR: float = 1e-3, EPOCHS: int = 400,\n",
        "    SIGMA_MIN: float = 5e-2, SIGMA_MAX: float = 1.0,\n",
        "    TOPK: int = 2, SMOOTH_EPS: float = 0.05\n",
        ") -> Dict[str, np.ndarray]:\n",
        "    \"\"\"\n",
        "    Run ``train_one_split_no_anchor`` over ``n_splits`` random outer splits.\n",
        "\n",
        "    Split ``i`` (1-based) is trained with ``seed + i``. Nothing is printed; the\n",
        "    result maps \"rmse\", \"nll\", \"r2\" (float64) and \"best_ep\" (int32) to one\n",
        "    array entry per split.\n",
        "    \"\"\"\n",
        "    outer_splits = make_random_splits(len(X), train_frac=outer_train_frac, n_splits=n_splits, seed=seed)\n",
        "    shared_kwargs = dict(\n",
        "        standardize_x=standardize_x,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "    )\n",
        "\n",
        "    records = []  # one (rmse, nll, best_ep, r2) tuple per split\n",
        "    for itr, (tr_idx, te_idx) in enumerate(outer_splits, 1):\n",
        "        records.append(train_one_split_no_anchor(\n",
        "            X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "            X[te_idx].copy(), y[te_idx].copy(),\n",
        "            seed=seed+itr,\n",
        "            **shared_kwargs\n",
        "        ))\n",
        "        # Release per-split memory before the next model is built.\n",
        "        gc.collect()\n",
        "        if torch.cuda.is_available():\n",
        "            torch.cuda.empty_cache()\n",
        "\n",
        "    return {\n",
        "        \"rmse\": np.asarray([rec[0] for rec in records], dtype=np.float64),\n",
        "        \"nll\":  np.asarray([rec[1] for rec in records], dtype=np.float64),\n",
        "        \"r2\":   np.asarray([rec[3] for rec in records], dtype=np.float64),\n",
        "        \"best_ep\": np.asarray([rec[2] for rec in records], dtype=np.int32),\n",
        "    }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ADOLkM54KMqI"
      },
      "source": [
        "## Housing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gN7D3b1PK8Al",
        "outputId": "0d0ced22-2c97-466e-cdaf-2f55d98ac11f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[housing] X.shape=(506, 13) y.shape=(506,) | y∈[5.000,50.000]\n",
            "[01/20] MoE best_ep= 21  TestRMSE(orig)=2.7858  TestNLL(z)=0.2485\n",
            "[02/20] MoE best_ep= 25  TestRMSE(orig)=2.7184  TestNLL(z)=0.4500\n",
            "[03/20] MoE best_ep= 40  TestRMSE(orig)=5.8159  TestNLL(z)=5.3371\n",
            "[04/20] MoE best_ep= 21  TestRMSE(orig)=3.9054  TestNLL(z)=0.6567\n",
            "[05/20] MoE best_ep= 31  TestRMSE(orig)=5.5873  TestNLL(z)=0.8501\n",
            "[06/20] MoE best_ep= 20  TestRMSE(orig)=3.1320  TestNLL(z)=0.3391\n",
            "[07/20] MoE best_ep= 21  TestRMSE(orig)=2.3982  TestNLL(z)=0.2805\n",
            "[08/20] MoE best_ep= 16  TestRMSE(orig)=4.6720  TestNLL(z)=0.3427\n",
            "[09/20] MoE best_ep= 23  TestRMSE(orig)=3.9425  TestNLL(z)=1.2395\n",
            "[10/20] MoE best_ep= 19  TestRMSE(orig)=5.7289  TestNLL(z)=0.6667\n",
            "[11/20] MoE best_ep= 29  TestRMSE(orig)=5.9790  TestNLL(z)=0.9890\n",
            "[12/20] MoE best_ep= 17  TestRMSE(orig)=5.2189  TestNLL(z)=0.5063\n",
            "[13/20] MoE best_ep= 19  TestRMSE(orig)=3.0202  TestNLL(z)=0.3520\n",
            "[14/20] MoE best_ep= 17  TestRMSE(orig)=4.4091  TestNLL(z)=0.9967\n",
            "[15/20] MoE best_ep= 21  TestRMSE(orig)=3.3780  TestNLL(z)=0.5771\n",
            "[16/20] MoE best_ep= 17  TestRMSE(orig)=2.5204  TestNLL(z)=0.4315\n",
            "[17/20] MoE best_ep= 23  TestRMSE(orig)=3.8297  TestNLL(z)=0.5794\n",
            "[18/20] MoE best_ep= 18  TestRMSE(orig)=3.7827  TestNLL(z)=0.5561\n",
            "[19/20] MoE best_ep= 17  TestRMSE(orig)=3.6376  TestNLL(z)=0.4851\n",
            "[20/20] MoE best_ep= 19  TestRMSE(orig)=6.3140  TestNLL(z)=0.7367\n",
            "\n",
            "== MoE (No-Anchor, no-leak) on Housing ==\n",
            "RMSE (orig) = 4.1388 ± 0.2809\n",
            "NLL  (z)    = 0.8310 ± 0.2443\n"
          ]
        }
      ],
      "source": [
        "# === Housing (UCI Boston): MoE without anchor, leak-free protocol ===\n",
        "SEED = 1\n",
        "set_seed(SEED)\n",
        "X, y = load_dataset(\"housing\", verbose=True)\n",
        "\n",
        "# Re-seed right before drawing the outer splits so they depend only on SEED.\n",
        "set_seed(SEED)\n",
        "\n",
        "# 20 random 90/10 outer splits drawn from NumPy's global RNG.\n",
        "n = X.shape[0]\n",
        "n_train = round(n * 0.9)\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "# Hyper-parameters for this dataset.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, _r2 = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(), y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr,\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse)\n",
        "    nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls = np.asarray(nlls, dtype=np.float64)\n",
        "\n",
        "def se(a):\n",
        "    \"\"\"Standard error of the mean (sample std with ddof=1 over sqrt(n)).\"\"\"\n",
        "    return a.std(ddof=1) / np.sqrt(len(a))\n",
        "\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) on Housing ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "02_CZ8ERQbXU"
      },
      "source": [
        "## Concrete"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hRCvFxVDQeWT",
        "outputId": "ca86f167-3f76-4e0d-e8b7-d63d5ce50fb3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[concrete] X.shape=(1030, 8) y.shape=(1030,) | y∈[2.332,82.599]\n",
            "[01/20] MoE best_ep= 31  TestRMSE(orig)=8.6908  TestNLL(z)=0.8044\n",
            "[02/20] MoE best_ep= 29  TestRMSE(orig)=9.0519  TestNLL(z)=0.8425\n",
            "[03/20] MoE best_ep= 34  TestRMSE(orig)=6.4428  TestNLL(z)=0.5905\n",
            "[04/20] MoE best_ep= 34  TestRMSE(orig)=7.3591  TestNLL(z)=0.7132\n",
            "[05/20] MoE best_ep= 51  TestRMSE(orig)=7.4073  TestNLL(z)=1.1140\n",
            "[06/20] MoE best_ep= 26  TestRMSE(orig)=8.3383  TestNLL(z)=0.6317\n",
            "[07/20] MoE best_ep= 27  TestRMSE(orig)=8.6215  TestNLL(z)=0.7152\n",
            "[08/20] MoE best_ep= 41  TestRMSE(orig)=7.5074  TestNLL(z)=0.8342\n",
            "[09/20] MoE best_ep= 48  TestRMSE(orig)=7.9033  TestNLL(z)=0.4478\n",
            "[10/20] MoE best_ep= 42  TestRMSE(orig)=7.2035  TestNLL(z)=0.6737\n",
            "[11/20] MoE best_ep= 28  TestRMSE(orig)=7.0081  TestNLL(z)=0.6334\n",
            "[12/20] MoE best_ep= 28  TestRMSE(orig)=7.7609  TestNLL(z)=0.6468\n",
            "[13/20] MoE best_ep= 57  TestRMSE(orig)=7.3684  TestNLL(z)=0.6108\n",
            "[14/20] MoE best_ep= 70  TestRMSE(orig)=6.6325  TestNLL(z)=0.9026\n",
            "[15/20] MoE best_ep= 35  TestRMSE(orig)=7.5010  TestNLL(z)=0.4450\n",
            "[16/20] MoE best_ep= 39  TestRMSE(orig)=8.1741  TestNLL(z)=0.9060\n",
            "[17/20] MoE best_ep= 30  TestRMSE(orig)=7.9139  TestNLL(z)=0.7475\n",
            "[18/20] MoE best_ep= 39  TestRMSE(orig)=8.1811  TestNLL(z)=0.8130\n",
            "[19/20] MoE best_ep= 33  TestRMSE(orig)=8.3587  TestNLL(z)=0.7938\n",
            "[20/20] MoE best_ep= 35  TestRMSE(orig)=7.5834  TestNLL(z)=0.7869\n",
            "\n",
            "== MoE (No-Anchor, no-leak) on Concrete ==\n",
            "RMSE (orig) = 7.7504 ± 0.1526\n",
            "NLL  (z)    = 0.7326 ± 0.0354\n"
          ]
        }
      ],
      "source": [
        "# === Concrete (UCI Concrete Compressive Strength): MoE without anchor, leak-free protocol ===\n",
        "SEED = 1\n",
        "set_seed(SEED)\n",
        "X, y = load_dataset(\"concrete\", verbose=True)\n",
        "\n",
        "# Re-seed right before drawing the outer splits so they depend only on SEED.\n",
        "set_seed(SEED)\n",
        "\n",
        "# 20 random 90/10 outer splits drawn from NumPy's global RNG.\n",
        "n = X.shape[0]\n",
        "n_train = round(n * 0.9)\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "# Hyper-parameters for this dataset.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, _r2 = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(), y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr,\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse)\n",
        "    nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.asarray(rmses, dtype=np.float64)\n",
        "nlls = np.asarray(nlls, dtype=np.float64)\n",
        "\n",
        "def se(a):\n",
        "    \"\"\"Standard error of the mean (sample std with ddof=1 over sqrt(n)).\"\"\"\n",
        "    return a.std(ddof=1) / np.sqrt(len(a))\n",
        "\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) on Concrete ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "su-fwkeFSJz4"
      },
      "source": [
        "## Energy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_FPssAVnSMBH",
        "outputId": "e609ded2-886f-4dc3-ae48-262dcabf1a07"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[energy] X.shape=(768, 8) y.shape=(768,) | y∈[6.010,43.100]\n",
            "[01/20] MoE best_ep=352  TestRMSE(orig)=0.6891  TestNLL(z)=-1.6435\n",
            "[02/20] MoE best_ep=399  TestRMSE(orig)=1.7665  TestNLL(z)=-1.1952\n",
            "[03/20] MoE best_ep=400  TestRMSE(orig)=1.0104  TestNLL(z)=-1.5112\n",
            "[04/20] MoE best_ep=291  TestRMSE(orig)=2.0335  TestNLL(z)=-1.2995\n",
            "[05/20] MoE best_ep=359  TestRMSE(orig)=1.3648  TestNLL(z)=-1.0606\n",
            "[06/20] MoE best_ep=399  TestRMSE(orig)=0.7653  TestNLL(z)=-1.4509\n",
            "[07/20] MoE best_ep=397  TestRMSE(orig)=1.4622  TestNLL(z)=-1.2443\n",
            "[08/20] MoE best_ep=342  TestRMSE(orig)=2.2867  TestNLL(z)=-1.0840\n",
            "[09/20] MoE best_ep=396  TestRMSE(orig)=0.8420  TestNLL(z)=-1.4574\n",
            "[10/20] MoE best_ep=355  TestRMSE(orig)=1.0255  TestNLL(z)=-0.7187\n",
            "[11/20] MoE best_ep=194  TestRMSE(orig)=2.1353  TestNLL(z)=-1.2581\n",
            "[12/20] MoE best_ep=344  TestRMSE(orig)=2.1433  TestNLL(z)=-1.4744\n",
            "[13/20] MoE best_ep=396  TestRMSE(orig)=2.2099  TestNLL(z)=-1.2638\n",
            "[14/20] MoE best_ep=317  TestRMSE(orig)=1.2415  TestNLL(z)=-1.2514\n",
            "[15/20] MoE best_ep=398  TestRMSE(orig)=1.7215  TestNLL(z)=-1.3686\n",
            "[16/20] MoE best_ep=395  TestRMSE(orig)=0.6124  TestNLL(z)=-1.6036\n",
            "[17/20] MoE best_ep=386  TestRMSE(orig)=1.4816  TestNLL(z)=-1.0153\n",
            "[18/20] MoE best_ep=398  TestRMSE(orig)=0.8002  TestNLL(z)=-1.5800\n",
            "[19/20] MoE best_ep=298  TestRMSE(orig)=1.4656  TestNLL(z)=-1.3234\n",
            "[20/20] MoE best_ep=316  TestRMSE(orig)=2.5563  TestNLL(z)=-1.1705\n",
            "\n",
            "== MoE (No-Anchor, no-leak) on Concrete ==\n",
            "RMSE (orig) = 1.4807 ± 0.1349\n",
            "NLL  (z)    = -1.2987 ± 0.0507\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing (UCI Concrete) — MoE w/o Anchor (leak-free) ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"energy\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS    = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "\n",
        "    rmse, nll, best_ep, _r2 = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.array(rmses, dtype=np.float64)\n",
        "nlls  = np.array(nlls,  dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) on Concrete ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EIZ3Ib5qTogE"
      },
      "source": [
        "## Kin8nm\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0tWBGb5qTn8P",
        "outputId": "ba2266b1-a53b-4332-ac02-afc96fae9348"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[kin8nm] X.shape=(8785, 8) y.shape=(8785,) | y∈[0.040,1.459]\n",
            "[01/20] MoE best_ep= 49  TestRMSE(orig)=0.1083  TestNLL(z)=0.7036\n",
            "[02/20] MoE best_ep= 44  TestRMSE(orig)=0.1113  TestNLL(z)=0.6597\n",
            "[03/20] MoE best_ep= 62  TestRMSE(orig)=0.0971  TestNLL(z)=0.6340\n",
            "[04/20] MoE best_ep= 55  TestRMSE(orig)=0.1097  TestNLL(z)=0.7729\n",
            "[05/20] MoE best_ep= 58  TestRMSE(orig)=0.0982  TestNLL(z)=0.6317\n",
            "[06/20] MoE best_ep= 49  TestRMSE(orig)=0.1095  TestNLL(z)=0.6828\n",
            "[07/20] MoE best_ep= 41  TestRMSE(orig)=0.1095  TestNLL(z)=0.4907\n",
            "[08/20] MoE best_ep= 57  TestRMSE(orig)=0.1102  TestNLL(z)=0.7145\n",
            "[09/20] MoE best_ep= 60  TestRMSE(orig)=0.1059  TestNLL(z)=0.8208\n",
            "[10/20] MoE best_ep= 53  TestRMSE(orig)=0.1089  TestNLL(z)=0.6962\n",
            "[11/20] MoE best_ep= 55  TestRMSE(orig)=0.1049  TestNLL(z)=0.7532\n",
            "[12/20] MoE best_ep= 56  TestRMSE(orig)=0.0999  TestNLL(z)=0.6729\n",
            "[13/20] MoE best_ep= 39  TestRMSE(orig)=0.1224  TestNLL(z)=0.6240\n",
            "[14/20] MoE best_ep= 55  TestRMSE(orig)=0.1023  TestNLL(z)=0.7362\n",
            "[15/20] MoE best_ep= 51  TestRMSE(orig)=0.1129  TestNLL(z)=0.6152\n",
            "[16/20] MoE best_ep= 54  TestRMSE(orig)=0.1060  TestNLL(z)=0.7001\n",
            "[17/20] MoE best_ep= 48  TestRMSE(orig)=0.1198  TestNLL(z)=0.7107\n",
            "[18/20] MoE best_ep= 62  TestRMSE(orig)=0.1127  TestNLL(z)=0.7516\n",
            "[19/20] MoE best_ep= 58  TestRMSE(orig)=0.1024  TestNLL(z)=0.5598\n",
            "[20/20] MoE best_ep= 54  TestRMSE(orig)=0.0970  TestNLL(z)=0.6118\n",
            "\n",
            "== MoE (No-Anchor, no-leak) on Concrete ==\n",
            "RMSE (orig) = 0.1074 ± 0.0015\n",
            "NLL  (z)    = 0.6771 ± 0.0172\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing (UCI Concrete) — MoE w/o Anchor (leak-free) ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"kin8nm\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS    = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "\n",
        "    rmse, nll, best_ep, _r2 = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.array(rmses, dtype=np.float64)\n",
        "nlls  = np.array(nlls,  dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) on Concrete ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZJj4QZuyWVtg"
      },
      "source": [
        "## Naval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 28,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Cbe1UX4iWVVE",
        "outputId": "a46d48a0-7183-496d-8fec-f64f93b13583"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[naval] X.shape=(11934, 16) y.shape=(11934,) | y∈[0.950,1.000]\n",
            "[01/20] MoE best_ep=398  TestRMSE(orig)=0.0028  TestNLL(z)=-0.6314\n",
            "[02/20] MoE best_ep=398  TestRMSE(orig)=0.0020  TestNLL(z)=-0.8896\n",
            "[03/20] MoE best_ep=400  TestRMSE(orig)=0.0019  TestNLL(z)=-0.5113\n",
            "[04/20] MoE best_ep=399  TestRMSE(orig)=0.0023  TestNLL(z)=-0.8911\n",
            "[05/20] MoE best_ep=400  TestRMSE(orig)=0.0022  TestNLL(z)=-0.9101\n",
            "[06/20] MoE best_ep=400  TestRMSE(orig)=0.0026  TestNLL(z)=-0.6032\n",
            "[07/20] MoE best_ep=389  TestRMSE(orig)=0.0027  TestNLL(z)=-0.5455\n",
            "[08/20] MoE best_ep=398  TestRMSE(orig)=0.0021  TestNLL(z)=-0.6815\n",
            "[09/20] MoE best_ep=398  TestRMSE(orig)=0.0038  TestNLL(z)=-0.2339\n",
            "[10/20] MoE best_ep=399  TestRMSE(orig)=0.0018  TestNLL(z)=-0.8764\n",
            "[11/20] MoE best_ep=394  TestRMSE(orig)=0.0022  TestNLL(z)=-0.5710\n",
            "[12/20] MoE best_ep=396  TestRMSE(orig)=0.0023  TestNLL(z)=-0.6328\n",
            "[13/20] MoE best_ep=356  TestRMSE(orig)=0.0025  TestNLL(z)=-0.3788\n",
            "[14/20] MoE best_ep=400  TestRMSE(orig)=0.0020  TestNLL(z)=-0.8213\n",
            "[15/20] MoE best_ep=399  TestRMSE(orig)=0.0023  TestNLL(z)=-0.6941\n",
            "[16/20] MoE best_ep=399  TestRMSE(orig)=0.0028  TestNLL(z)=-0.3238\n",
            "[17/20] MoE best_ep=399  TestRMSE(orig)=0.0040  TestNLL(z)=-0.8296\n",
            "[18/20] MoE best_ep=399  TestRMSE(orig)=0.0024  TestNLL(z)=-0.1082\n",
            "[19/20] MoE best_ep=375  TestRMSE(orig)=0.0034  TestNLL(z)=-0.2948\n",
            "[20/20] MoE best_ep=395  TestRMSE(orig)=0.0021  TestNLL(z)=-0.7648\n",
            "\n",
            "== MoE (No-Anchor, no-leak) on Concrete ==\n",
            "RMSE (orig) = 0.0025 ± 0.0001\n",
            "NLL  (z)    = -0.6096 ± 0.0536\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing (UCI Concrete) — MoE w/o Anchor (leak-free) ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"naval\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS    = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "\n",
        "    rmse, nll, best_ep, _r2 = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.array(rmses, dtype=np.float64)\n",
        "nlls  = np.array(nlls,  dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) on Concrete ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ez7hQ799bEKc"
      },
      "source": [
        "## Power"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VyLklNefbKLK",
        "outputId": "231285d3-c086-4ab5-9396-59416a4ecc1e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[power] X.shape=(9568, 4) y.shape=(9568,) | y∈[420.260,495.760]\n",
            "[01/20] MoE best_ep=160  TestRMSE(orig)=4.0409  TestNLL(z)=-0.0505\n",
            "[02/20] MoE best_ep=182  TestRMSE(orig)=4.0588  TestNLL(z)=-0.0333\n",
            "[03/20] MoE best_ep=105  TestRMSE(orig)=4.0045  TestNLL(z)=-0.0843\n",
            "[04/20] MoE best_ep= 94  TestRMSE(orig)=3.9047  TestNLL(z)=-0.1493\n",
            "[05/20] MoE best_ep= 97  TestRMSE(orig)=4.3740  TestNLL(z)=-0.0607\n",
            "[06/20] MoE best_ep=268  TestRMSE(orig)=4.2723  TestNLL(z)=0.3420\n",
            "[07/20] MoE best_ep= 72  TestRMSE(orig)=4.0344  TestNLL(z)=-0.0915\n",
            "[08/20] MoE best_ep=113  TestRMSE(orig)=3.9922  TestNLL(z)=-0.0535\n",
            "[09/20] MoE best_ep=151  TestRMSE(orig)=4.1195  TestNLL(z)=0.0094\n",
            "[10/20] MoE best_ep=187  TestRMSE(orig)=4.0239  TestNLL(z)=-0.1253\n",
            "[11/20] MoE best_ep=138  TestRMSE(orig)=3.8732  TestNLL(z)=-0.0933\n",
            "[12/20] MoE best_ep=148  TestRMSE(orig)=4.0663  TestNLL(z)=-0.1012\n",
            "[13/20] MoE best_ep=119  TestRMSE(orig)=3.9346  TestNLL(z)=-0.0924\n",
            "[14/20] MoE best_ep=202  TestRMSE(orig)=3.6629  TestNLL(z)=-0.1210\n",
            "[15/20] MoE best_ep=173  TestRMSE(orig)=3.9758  TestNLL(z)=-0.1682\n",
            "[16/20] MoE best_ep=111  TestRMSE(orig)=3.6762  TestNLL(z)=-0.1634\n",
            "[17/20] MoE best_ep= 90  TestRMSE(orig)=4.0310  TestNLL(z)=-0.0014\n",
            "[18/20] MoE best_ep=143  TestRMSE(orig)=4.2936  TestNLL(z)=0.1019\n",
            "[19/20] MoE best_ep=178  TestRMSE(orig)=4.1139  TestNLL(z)=0.0265\n",
            "[20/20] MoE best_ep=208  TestRMSE(orig)=3.7647  TestNLL(z)=-0.1403\n",
            "\n",
            "== MoE (No-Anchor, no-leak) on Concrete ==\n",
            "RMSE (orig) = 4.0109 ± 0.0411\n",
            "NLL  (z)    = -0.0525 ± 0.0258\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing (UCI Concrete) — MoE w/o Anchor (leak-free) ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"power\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS    = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "\n",
        "    rmse, nll, best_ep, _r2 = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.array(rmses, dtype=np.float64)\n",
        "nlls  = np.array(nlls,  dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) on Concrete ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9ftzT0MGeT3C"
      },
      "source": [
        "## Protein"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PrZ3H1DseTTk",
        "outputId": "3d0a2bc5-9b33-42d3-a4e9-6b84db1c0ece"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[protein] X.shape=(45730, 9) y.shape=(45730,) | y∈[0.000,20.999]\n",
            "[Protein] Downsampled from 45730 to 10000 samples.\n",
            "[01/20] MoE best_ep=199  TestRMSE(orig)=4.5878  TestNLL(z)=0.6299\n",
            "[02/20] MoE best_ep=166  TestRMSE(orig)=4.8964  TestNLL(z)=0.7341\n",
            "[03/20] MoE best_ep=163  TestRMSE(orig)=4.7943  TestNLL(z)=0.6521\n",
            "[04/20] MoE best_ep=136  TestRMSE(orig)=4.7698  TestNLL(z)=0.5796\n",
            "[05/20] MoE best_ep=156  TestRMSE(orig)=4.6493  TestNLL(z)=0.5348\n",
            "[06/20] MoE best_ep=154  TestRMSE(orig)=4.6814  TestNLL(z)=0.7044\n",
            "[07/20] MoE best_ep=121  TestRMSE(orig)=4.8829  TestNLL(z)=0.7170\n",
            "[08/20] MoE best_ep=153  TestRMSE(orig)=4.7358  TestNLL(z)=0.6628\n",
            "[09/20] MoE best_ep=127  TestRMSE(orig)=4.5731  TestNLL(z)=0.7679\n",
            "[10/20] MoE best_ep=265  TestRMSE(orig)=4.4756  TestNLL(z)=0.6018\n",
            "[11/20] MoE best_ep=215  TestRMSE(orig)=4.4832  TestNLL(z)=0.5566\n",
            "[12/20] MoE best_ep=157  TestRMSE(orig)=4.7060  TestNLL(z)=0.6141\n",
            "[13/20] MoE best_ep=193  TestRMSE(orig)=4.6114  TestNLL(z)=0.6180\n",
            "[14/20] MoE best_ep=229  TestRMSE(orig)=4.5261  TestNLL(z)=0.5666\n",
            "[15/20] MoE best_ep=125  TestRMSE(orig)=4.8909  TestNLL(z)=0.5830\n",
            "[16/20] MoE best_ep=243  TestRMSE(orig)=4.7139  TestNLL(z)=0.5786\n",
            "[17/20] MoE best_ep=137  TestRMSE(orig)=4.9578  TestNLL(z)=0.6838\n",
            "[18/20] MoE best_ep=182  TestRMSE(orig)=4.8596  TestNLL(z)=0.6026\n",
            "[19/20] MoE best_ep=173  TestRMSE(orig)=4.9088  TestNLL(z)=0.7040\n",
            "[20/20] MoE best_ep=164  TestRMSE(orig)=4.5730  TestNLL(z)=0.5638\n",
            "\n",
            "== MoE (No-Anchor, no-leak) on Protein (10k downsample) ==\n",
            "RMSE (orig) = 4.7139 ± 0.0340\n",
            "NLL  (z)    = 0.6328 ± 0.0149\n"
          ]
        }
      ],
      "source": [
        "# === Run on Protein (CASP) — MoE w/o Anchor (leak-free, downsample to 10k) ===\n",
        "import gc, numpy as np, torch\n",
        "\n",
        "set_seed(1)\n",
        "# 若你之前的 load_dataset 支持 verbose 参数，就保留；否则去掉 verbose=True\n",
        "X, y = load_dataset(\"protein\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "# ---- 下采样到 10k（仅使用随机索引，不看 y，避免任何泄漏）----\n",
        "MAX_N = 10_000\n",
        "orig_n = X.shape[0]\n",
        "if orig_n > MAX_N:\n",
        "    rng_ds = np.random.RandomState(SEED)\n",
        "    keep = rng_ds.choice(orig_n, size=MAX_N, replace=False)\n",
        "    X = X[keep].copy()\n",
        "    y = y[keep].copy()\n",
        "    print(f\"[Protein] Downsampled from {orig_n} to {X.shape[0]} samples.\")\n",
        "else:\n",
        "    print(f\"[Protein] Size {orig_n} <= {MAX_N}, no downsampling.\")\n",
        "\n",
        "# ---- 20 次 90/10 外层划分（与既有协议一致）----\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "rng = np.random.RandomState(SEED)\n",
        "for i in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "# ---- 超参（与前面无 Anchor 版本一致）----\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS    = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, best_ep, _ = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.array(rmses, dtype=np.float64)\n",
        "nlls  = np.array(nlls,  dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) on Protein (10k downsample) ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J-ceBfWnju3N"
      },
      "source": [
        "\n",
        "## Wine\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "cLDi4HrmiiKq",
        "outputId": "29164ef3-8d0e-4d62-b7d6-d8ba2c947e43"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[wine] X.shape=(1599, 11) y.shape=(1599,) | y∈[3.000,8.000]\n",
            "[01/20] MoE best_ep= 18  TestRMSE(orig)=0.6403  TestNLL(z)=0.9994\n",
            "[02/20] MoE best_ep= 18  TestRMSE(orig)=0.6666  TestNLL(z)=1.1749\n",
            "[03/20] MoE best_ep=101  TestRMSE(orig)=0.7408  TestNLL(z)=0.6099\n",
            "[04/20] MoE best_ep= 72  TestRMSE(orig)=0.6307  TestNLL(z)=0.4046\n",
            "[05/20] MoE best_ep= 44  TestRMSE(orig)=0.5915  TestNLL(z)=0.4972\n",
            "[06/20] MoE best_ep= 17  TestRMSE(orig)=0.6495  TestNLL(z)=1.0928\n",
            "[07/20] MoE best_ep= 30  TestRMSE(orig)=0.6773  TestNLL(z)=1.0013\n",
            "[08/20] MoE best_ep=201  TestRMSE(orig)=0.7201  TestNLL(z)=6.6193\n",
            "[09/20] MoE best_ep= 84  TestRMSE(orig)=0.6552  TestNLL(z)=0.1123\n",
            "[10/20] MoE best_ep= 41  TestRMSE(orig)=0.6587  TestNLL(z)=1.1153\n",
            "[11/20] MoE best_ep=110  TestRMSE(orig)=0.6476  TestNLL(z)=6.8915\n",
            "[12/20] MoE best_ep= 29  TestRMSE(orig)=0.6484  TestNLL(z)=0.7153\n",
            "[13/20] MoE best_ep= 20  TestRMSE(orig)=0.6568  TestNLL(z)=1.0497\n",
            "[14/20] MoE best_ep=205  TestRMSE(orig)=0.6829  TestNLL(z)=3.5953\n",
            "[15/20] MoE best_ep= 26  TestRMSE(orig)=0.6567  TestNLL(z)=1.0634\n",
            "[16/20] MoE best_ep= 31  TestRMSE(orig)=0.5974  TestNLL(z)=0.8293\n",
            "[17/20] MoE best_ep= 86  TestRMSE(orig)=0.6173  TestNLL(z)=0.4552\n",
            "[18/20] MoE best_ep= 53  TestRMSE(orig)=0.6322  TestNLL(z)=0.7670\n",
            "[19/20] MoE best_ep= 27  TestRMSE(orig)=0.6675  TestNLL(z)=1.2178\n",
            "[20/20] MoE best_ep= 63  TestRMSE(orig)=0.6249  TestNLL(z)=0.1506\n",
            "\n",
            "== MoE (No-Anchor, no-leak) on Concrete ==\n",
            "RMSE (orig) = 0.6531 ± 0.0080\n",
            "NLL  (z)    = 1.5181 ± 0.4307\n"
          ]
        }
      ],
      "source": [
        "# === Run on Housing (UCI Concrete) — MoE w/o Anchor (leak-free) ===\n",
        "set_seed(1)\n",
        "X, y = load_dataset(\"wine\", verbose=True)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS    = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "\n",
        "    rmse, nll, best_ep, _r2 = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.array(rmses, dtype=np.float64)\n",
        "nlls  = np.array(nlls,  dtype=np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) on Concrete ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1Xv5jjAVjzHn"
      },
      "source": [
        "## Yacht"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "pLGgFHxoj260",
        "outputId": "4d7d4521-e373-47cb-d2b6-d2ab4d4310f6"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[yacht] X.shape=(308, 6) y.shape=(308,) | y∈[0.010,62.420]\n",
            "[01/20] MoE best_ep= 49  TestRMSE(orig)=2.8491  TestNLL(z)=-0.6580\n",
            "[02/20] MoE best_ep= 64  TestRMSE(orig)=6.3435  TestNLL(z)=-0.8389\n",
            "[03/20] MoE best_ep= 61  TestRMSE(orig)=3.0756  TestNLL(z)=-0.3659\n",
            "[04/20] MoE best_ep=380  TestRMSE(orig)=3.8330  TestNLL(z)=-0.4752\n",
            "[05/20] MoE best_ep=183  TestRMSE(orig)=4.8626  TestNLL(z)=-0.7710\n",
            "[06/20] MoE best_ep= 55  TestRMSE(orig)=4.0677  TestNLL(z)=-0.4541\n",
            "[07/20] MoE best_ep=371  TestRMSE(orig)=1.2637  TestNLL(z)=-1.2975\n",
            "[08/20] MoE best_ep=326  TestRMSE(orig)=3.0515  TestNLL(z)=-0.9351\n",
            "[09/20] MoE best_ep= 49  TestRMSE(orig)=4.8756  TestNLL(z)=-0.4014\n",
            "[10/20] MoE best_ep= 48  TestRMSE(orig)=2.7649  TestNLL(z)=-0.5411\n",
            "[11/20] MoE best_ep=138  TestRMSE(orig)=4.0778  TestNLL(z)=0.2368\n",
            "[12/20] MoE best_ep=381  TestRMSE(orig)=3.0999  TestNLL(z)=0.2827\n",
            "[13/20] MoE best_ep=180  TestRMSE(orig)=5.2409  TestNLL(z)=2.3303\n",
            "[14/20] MoE best_ep=156  TestRMSE(orig)=4.7693  TestNLL(z)=-0.8719\n",
            "[15/20] MoE best_ep=201  TestRMSE(orig)=3.8388  TestNLL(z)=3.2956\n",
            "[16/20] MoE best_ep=240  TestRMSE(orig)=3.0965  TestNLL(z)=-0.6161\n",
            "[17/20] MoE best_ep= 29  TestRMSE(orig)=4.6860  TestNLL(z)=-0.6509\n",
            "[18/20] MoE best_ep= 34  TestRMSE(orig)=5.5298  TestNLL(z)=-0.5265\n",
            "[19/20] MoE best_ep= 45  TestRMSE(orig)=4.5363  TestNLL(z)=1.5535\n",
            "[20/20] MoE best_ep= 61  TestRMSE(orig)=8.0117  TestNLL(z)=6.5179\n",
            "\n",
            "== MoE (No-Anchor, no-leak) ==\n",
            "RMSE (orig) = 4.1937 ± 0.3299\n",
            "NLL  (z)    = 0.2407 ± 0.4209\n"
          ]
        }
      ],
      "source": [
        "# === Run on Yacht — MoE w/o Anchor (leak-free) ===\n",
        "# Draws N_SPLITS random train/test partitions of the yacht data, trains the\n",
        "# no-anchor MoE once per partition, and reports mean ± standard error of the\n",
        "# per-split test RMSE and test NLL.\n",
        "SEED = 1          # base seed; split number `itr` trains with seed SEED + itr\n",
        "N_SPLITS = 20     # number of random train/test partitions\n",
        "TRAIN_FRAC = 0.9  # fraction of rows placed on the training side of each split\n",
        "\n",
        "set_seed(SEED)\n",
        "X, y = load_dataset(\"yacht\", verbose=True)\n",
        "\n",
        "# Re-seed after loading so the partitions below start from a freshly seeded\n",
        "# RNG, independent of any random numbers load_dataset may have drawn.\n",
        "set_seed(SEED)\n",
        "\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for _split in range(N_SPLITS):\n",
        "    # Shuffle all row indices; the first n_train go to train, the rest to test.\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "# Hyperparameters, forwarded unchanged to train_one_split_no_anchor.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS    = 1e-3, 400\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    # Each split hands the trainer its own copies of the selected rows.\n",
        "    # The fourth return value (_r2) is not used in this cell.\n",
        "    rmse, nll, best_ep, _r2 = train_one_split_no_anchor(\n",
        "        X[tr_idx].copy(), y[tr_idx].copy(),\n",
        "        X[te_idx].copy(),  y[te_idx].copy(),\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS,\n",
        "        seed=SEED + itr\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{N_SPLITS:02d}] MoE best_ep={best_ep:3d}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse)\n",
        "    nlls.append(nll)\n",
        "    # Drop per-split garbage and cached GPU memory before the next run.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses = np.array(rmses, dtype=np.float64)\n",
        "nlls  = np.array(nlls,  dtype=np.float64)\n",
        "\n",
        "\n",
        "def se(a):\n",
        "    \"\"\"Standard error of the mean, using the sample (ddof=1) standard deviation.\"\"\"\n",
        "    return a.std(ddof=1) / np.sqrt(len(a))\n",
        "\n",
        "\n",
        "print(\"\\n== MoE (No-Anchor, no-leak) ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
