{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "a67dc13a-f1f9-45c8-89af-468824495286",
      "metadata": {
        "id": "a67dc13a-f1f9-45c8-89af-468824495286"
      },
      "source": [
        "\n",
        "# **Empirical Experiments of Anchor-MoE**\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "305e6090-f377-4cf1-808a-89c67049fd93",
      "metadata": {
        "id": "305e6090-f377-4cf1-808a-89c67049fd93"
      },
      "source": [
        "# Public Datasets\n",
        "\n",
        "Ten UCI benchmark regression datasets: housing, concrete, wine (red), kin8nm, naval, power, energy, protein, yacht and MSD.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0dfd64d2-2fb8-4c86-95fe-7485cbd41450",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0dfd64d2-2fb8-4c86-95fe-7485cbd41450",
        "outputId": "c68e344e-e684-493a-bcee-6c4aa995998e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting ngboost\n",
            "  Downloading ngboost-0.5.6-py3-none-any.whl.metadata (4.0 kB)\n",
            "Collecting lifelines>=0.25 (from ngboost)\n",
            "  Downloading lifelines-0.30.0-py3-none-any.whl.metadata (3.2 kB)\n",
            "Requirement already satisfied: numpy>=1.21.2 in /usr/local/lib/python3.11/dist-packages (from ngboost) (2.0.2)\n",
            "Requirement already satisfied: scikit-learn<2.0,>=1.6 in /usr/local/lib/python3.11/dist-packages (from ngboost) (1.6.1)\n",
            "Requirement already satisfied: scipy>=1.7.2 in /usr/local/lib/python3.11/dist-packages (from ngboost) (1.16.1)\n",
            "Requirement already satisfied: tqdm>=4.3 in /usr/local/lib/python3.11/dist-packages (from ngboost) (4.67.1)\n",
            "Requirement already satisfied: pandas>=2.1 in /usr/local/lib/python3.11/dist-packages (from lifelines>=0.25->ngboost) (2.2.2)\n",
            "Requirement already satisfied: matplotlib>=3.0 in /usr/local/lib/python3.11/dist-packages (from lifelines>=0.25->ngboost) (3.10.0)\n",
            "Collecting autograd>=1.5 (from lifelines>=0.25->ngboost)\n",
            "  Downloading autograd-1.8.0-py3-none-any.whl.metadata (7.5 kB)\n",
            "Collecting autograd-gamma>=0.3 (from lifelines>=0.25->ngboost)\n",
            "  Downloading autograd-gamma-0.5.0.tar.gz (4.0 kB)\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Collecting formulaic>=0.2.2 (from lifelines>=0.25->ngboost)\n",
            "  Downloading formulaic-1.2.0-py3-none-any.whl.metadata (7.0 kB)\n",
            "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.6->ngboost) (1.5.1)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<2.0,>=1.6->ngboost) (3.6.0)\n",
            "Collecting interface-meta>=1.2.0 (from formulaic>=0.2.2->lifelines>=0.25->ngboost)\n",
            "  Downloading interface_meta-1.3.0-py3-none-any.whl.metadata (6.7 kB)\n",
            "Requirement already satisfied: narwhals>=1.17 in /usr/local/lib/python3.11/dist-packages (from formulaic>=0.2.2->lifelines>=0.25->ngboost) (2.1.1)\n",
            "Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.11/dist-packages (from formulaic>=0.2.2->lifelines>=0.25->ngboost) (4.14.1)\n",
            "Requirement already satisfied: wrapt>=1.0 in /usr/local/lib/python3.11/dist-packages (from formulaic>=0.2.2->lifelines>=0.25->ngboost) (1.17.3)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.0->lifelines>=0.25->ngboost) (1.3.3)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.0->lifelines>=0.25->ngboost) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.0->lifelines>=0.25->ngboost) (4.59.0)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.0->lifelines>=0.25->ngboost) (1.4.9)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.0->lifelines>=0.25->ngboost) (25.0)\n",
            "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.0->lifelines>=0.25->ngboost) (11.3.0)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.0->lifelines>=0.25->ngboost) (3.2.3)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib>=3.0->lifelines>=0.25->ngboost) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas>=2.1->lifelines>=0.25->ngboost) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas>=2.1->lifelines>=0.25->ngboost) (2025.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib>=3.0->lifelines>=0.25->ngboost) (1.17.0)\n",
            "Downloading ngboost-0.5.6-py3-none-any.whl (35 kB)\n",
            "Downloading lifelines-0.30.0-py3-none-any.whl (349 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m349.3/349.3 kB\u001b[0m \u001b[31m20.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading autograd-1.8.0-py3-none-any.whl (51 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m51.5/51.5 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading formulaic-1.2.0-py3-none-any.whl (117 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m117.2/117.2 kB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading interface_meta-1.3.0-py3-none-any.whl (14 kB)\n",
            "Building wheels for collected packages: autograd-gamma\n",
            "  Building wheel for autograd-gamma (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for autograd-gamma: filename=autograd_gamma-0.5.0-py3-none-any.whl size=4119 sha256=15a25619e34ae2d1c50e9c3586dace334dfe581f32863180afb9658467f7f8be\n",
            "  Stored in directory: /root/.cache/pip/wheels/8b/67/f4/2caaae2146198dcb824f31a303833b07b14a5ec863fb3acd7b\n",
            "Successfully built autograd-gamma\n",
            "Installing collected packages: interface-meta, autograd, autograd-gamma, formulaic, lifelines, ngboost\n",
            "Successfully installed autograd-1.8.0 autograd-gamma-0.5.0 formulaic-1.2.0 interface-meta-1.3.0 lifelines-0.30.0 ngboost-0.5.6\n",
            "Collecting ucimlrepo\n",
            "  Downloading ucimlrepo-0.0.7-py3-none-any.whl.metadata (5.5 kB)\n",
            "Requirement already satisfied: pandas>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from ucimlrepo) (2.2.2)\n",
            "Requirement already satisfied: certifi>=2020.12.5 in /usr/local/lib/python3.11/dist-packages (from ucimlrepo) (2025.8.3)\n",
            "Requirement already satisfied: numpy>=1.23.2 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.0.0->ucimlrepo) (2.0.2)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.0.0->ucimlrepo) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.0.0->ucimlrepo) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.0.0->ucimlrepo) (2025.2)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas>=1.0.0->ucimlrepo) (1.17.0)\n",
            "Downloading ucimlrepo-0.0.7-py3-none-any.whl (8.0 kB)\n",
            "Installing collected packages: ucimlrepo\n",
            "Successfully installed ucimlrepo-0.0.7\n"
          ]
        }
      ],
      "source": [
        "%pip install ngboost==0.5.6\n",
        "%pip install ucimlrepo==0.0.7"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "WGsN1fc911vG",
      "metadata": {
        "id": "WGsN1fc911vG"
      },
      "source": [
        "# **Model Definition**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "oKPIrPy11zFu",
      "metadata": {
        "id": "oKPIrPy11zFu"
      },
      "outputs": [],
      "source": [
        "\n",
        "import os\n",
        "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\" \n",
        "\n",
        "import math, gc\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import train_test_split\n",
        "from sklearn.metrics import mean_squared_error\n",
        "from sklearn.ensemble import GradientBoostingRegressor\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "\n",
        "\n",
        "LOG2PI  = math.log(2*math.pi)\n",
        "DEVICE  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "torch.set_default_dtype(torch.float32)\n",
        "\n",
        "\n",
        "def rmse_score(y_true, y_pred):\n",
        "    \"\"\"Root-mean-squared error of ``y_pred`` against ``y_true``, returned as a plain Python float.\"\"\"\n",
        "    mse = mean_squared_error(y_true, y_pred)\n",
        "    return float(np.sqrt(mse))\n",
        "\n",
        "def load_dataset(name: str):\n",
        "    def _clean_numeric(df: pd.DataFrame) -> pd.DataFrame:\n",
        "        df = df.apply(pd.to_numeric, errors=\"coerce\")\n",
        "        df = df.replace([np.inf, -np.inf], np.nan).dropna(axis=0)\n",
        "        return df\n",
        "\n",
        "    name = name.lower().strip()\n",
        "\n",
        "    if name == \"yacht\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/00243/yacht_hydrodynamics.data\",\n",
        "            header=None, sep=r\"\\s+\", engine=\"python\", comment=\"#\", skip_blank_lines=True\n",
        "        )\n",
        "        if df.shape[1] == 1:\n",
        "            df = df[0].astype(str).str.strip().str.split(r\"\\s+\", expand=True)\n",
        "        df = _clean_numeric(df); assert df.shape[1] >= 7\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        print(f\"[Yacht] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"housing\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data\",\n",
        "            header=None, sep=r\"\\s+\", engine=\"python\", comment=\"#\", skip_blank_lines=True\n",
        "        )\n",
        "        if df.shape[1] == 1:\n",
        "            df = df[0].astype(str).str.strip().str.split(r\"\\s+\", expand=True)\n",
        "        df = _clean_numeric(df); assert df.shape[1] >= 14\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        print(f\"[Housing] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"concrete\":\n",
        "        df = pd.read_excel(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/concrete/compressive/Concrete_Data.xls\"\n",
        "        )\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        print(f\"[Concrete] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"wine\":\n",
        "        df = pd.read_csv(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv\",\n",
        "            delimiter=\";\"\n",
        "        )\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        print(f\"[WineRed] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"power\":\n",
        "       local = \"data/power.xlsx\" \n",
        "       if not os.path.exists(local):\n",
        "          raise FileNotFoundError(\n",
        "            \"[power]'power.xlsx'（UCI CCPP 的 Folds5x2_pp.xlsx）。\")\n",
        "       df = pd.read_excel(local)\n",
        "       df = _clean_numeric(df)\n",
        "    if name == \"energy\":\n",
        "        df = pd.read_excel(\n",
        "            \"https://archive.ics.uci.edu/ml/machine-learning-databases/00242/ENB2012_data.xlsx\"\n",
        "        ).iloc[:, :-1]\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64)  \n",
        "        y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        print(f\"[Energy] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"protein\":\n",
        "       local = \"data/protein.csv\"  \n",
        "       df = pd.read_csv(local, sep=None, engine=\"python\") \n",
        "       df = _clean_numeric(df)\n",
        "       y = df.iloc[:, 0].to_numpy(np.float64)\n",
        "       X = df.iloc[:, 1:].to_numpy(np.float64)\n",
        "       y_min, y_max = float(np.nanmin(y)), float(np.nanmax(y))\n",
        "       print(f\"[Protein(local)] X.shape={X.shape} y.shape={y.shape} | y∈[{y_min:.3f},{y_max:.3f}]\")\n",
        "       return X, y\n",
        "\n",
        "    if name == \"naval\":\n",
        "       local = \"data/naval.txt\"  \n",
        "       df = pd.read_csv(local, sep=r\"\\s+\", header=None, engine=\"python\")\n",
        "       df = df.iloc[:, :-1]\n",
        "       df = _clean_numeric(df)\n",
        "       X = df.iloc[:, :-1].to_numpy(np.float64)\n",
        "       y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "       print(f\"[Naval(local)] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.6f},{y.max():.6f}]\")\n",
        "       return X, y\n",
        "\n",
        "    if name == \"kin8nm\":\n",
        "        local = \"data/kin8nm.csv\"\n",
        "        if not os.path.exists(local): raise FileNotFoundError(\"[kin8nm] require local 'kin8nm.csv'.\")\n",
        "        df = pd.read_csv(local); df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        print(f\"[Kin8nm] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    if name == \"msd\":\n",
        "        local = \"data/YearPredictionMSD.txt\"\n",
        "        if not os.path.exists(local): raise FileNotFoundError(\"[msd] require local 'YearPredictionMSD.txt'.\")\n",
        "        df = pd.read_csv(local, header=None); df = df.iloc[:, ::-1]\n",
        "        df = _clean_numeric(df)\n",
        "        X = df.iloc[:, :-1].to_numpy(np.float64); y = df.iloc[:, -1].to_numpy(np.float64)\n",
        "        print(f\"[MSD] X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "        return X, y\n",
        "\n",
        "    raise ValueError(\"Unknown dataset. options: 'housing','concrete','wine','kin8nm','naval','power','energy','protein','yacht','msd'\")\n",
        "\n",
        "def zscore_fit(y_train):\n",
        "    my = float(y_train.mean()); sy = float(y_train.std() + 1e-8)\n",
        "    return my, sy\n",
        "\n",
        "def _topk_mask(w, k=4):\n",
        "    _, topi = torch.topk(w, k, dim=-1)\n",
        "    mask = torch.zeros_like(w).scatter_(-1, topi, 1.0)\n",
        "    w2 = w * mask\n",
        "    return w2 / (w2.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "\n",
        "def _topk_mask_smooth(w, k=2, eps=0.05):\n",
        "    _, topi = torch.topk(w, k, dim=-1)\n",
        "    mask = torch.zeros_like(w).scatter_(-1, topi, 1.0)\n",
        "    w_top = w * mask\n",
        "    w_top = w_top / (w_top.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "    return (1.0 - eps) * w_top + (eps / k) * mask\n",
        "\n",
        "class Projection(nn.Module):\n",
        "    def __init__(self, d, D):\n",
        "        super().__init__()\n",
        "        self.w = nn.Linear(d, D, bias=True)\n",
        "        nn.init.xavier_uniform_(self.w.weight); nn.init.zeros_(self.w.bias)\n",
        "    def forward(self, x): return self.w(x)\n",
        "\n",
        "class Window(nn.Module):\n",
        "    def __init__(self, K, D, min_log_s=-2.5, max_log_s=1.0):\n",
        "        super().__init__()\n",
        "        self.c         = nn.Parameter(torch.randn(K, D))\n",
        "        self.log_s     = nn.Parameter(torch.zeros(K, D))\n",
        "        self.min_log_s = min_log_s; self.max_log_s = max_log_s\n",
        "    def forward(self, z):\n",
        "        log_s = torch.clamp(self.log_s, min=self.min_log_s, max=self.max_log_s)\n",
        "        diff2 = ((z[:, None] - self.c)**2) / (2 * torch.exp(log_s)**2)\n",
        "        return torch.exp(-diff2.sum(dim=-1)) + 1e-12  # [B,K]\n",
        "\n",
        "class Router(nn.Module):\n",
        "    def __init__(self, D, K):\n",
        "        super().__init__()\n",
        "        self.q   = nn.Linear(D, 64)\n",
        "        self.k   = nn.Parameter(torch.randn(K, 64))\n",
        "        self.tau = 3.0\n",
        "    def forward(self, z):\n",
        "        logits = (self.q(z) @ self.k.T) / math.sqrt(64)\n",
        "        return F.softmax(logits / self.tau, dim=-1)\n",
        "\n",
        "class ExpertMDN(nn.Module):\n",
        "    def __init__(self, d, h, nc, sigma_min=5e-2, sigma_max=1.0, learn_mean=True):\n",
        "        super().__init__()\n",
        "        self.net = nn.Sequential(\n",
        "            nn.Linear(d, h), nn.ReLU(),\n",
        "            nn.Linear(h, h), nn.ReLU()\n",
        "        )\n",
        "        self.logits = nn.Linear(h, nc)\n",
        "        self.learn_mean = learn_mean\n",
        "        if learn_mean:\n",
        "            self.means  = nn.Linear(h, nc)\n",
        "        self.log_sc = nn.Linear(h, nc)\n",
        "        self.sigma_min = float(sigma_min)\n",
        "        self.sigma_max = float(sigma_max)\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, nn.Linear):\n",
        "                nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)\n",
        "    def forward(self, x):\n",
        "        h  = self.net(x)\n",
        "        pi = F.softmax(self.logits(h), dim=-1)\n",
        "        mu = self.means(h) if self.learn_mean else None\n",
        "        sg = torch.exp(self.log_sc(h)).clamp(self.sigma_min, self.sigma_max)\n",
        "        return pi, mu, sg\n",
        "\n",
        "class BLRMoE(nn.Module):\n",
        "    \"\"\"Mixture of MDN experts gated by Gaussian windows x a dot-product router, with top-k routing.\n",
        "\n",
        "    mean_mode: 'anchor' | 'anchor+delta' | 'free'\n",
        "      - 'anchor':       every component mean equals the supplied anchor ``mu_anchor_z``.\n",
        "      - 'anchor+delta': anchor plus a learned per-component offset, L2-penalised by ``delta_l2``.\n",
        "      - 'free':         component means are learned outright; the anchor is ignored.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, d, D, K, hid, nc,\n",
        "                 mean_mode='anchor+delta', delta_l2=1e-3,\n",
        "                 w_ent_warm=+1e-3, w_ent_cool=-2e-4, l2_win=1e-4, lb_coef=1e-3,\n",
        "                 sigma_min=5e-2, sigma_max=1.0,\n",
        "                 topk=2, smooth_eps=0.05):\n",
        "        super().__init__()\n",
        "        assert mean_mode in ['anchor','anchor+delta','free']\n",
        "        self.mean_mode = mean_mode\n",
        "        self.delta_l2 = float(delta_l2)\n",
        "        # Build order (proj, win, router, experts) fixes RNG use at init and parameter order.\n",
        "        self.proj = Projection(d, D)\n",
        "        self.win = Window(K, D)\n",
        "        self.router = Router(D, K)\n",
        "        self.exps = nn.ModuleList([\n",
        "            ExpertMDN(d, hid, nc, sigma_min, sigma_max, learn_mean=(mean_mode != 'anchor'))\n",
        "            for _ in range(K)\n",
        "        ])\n",
        "        # Regulariser weights used by nll() in training mode.\n",
        "        self.w_ent_warm, self.w_ent_cool = w_ent_warm, w_ent_cool\n",
        "        self.l2_win, self.lb_coef = l2_win, lb_coef\n",
        "        self.topk, self.smooth_eps = topk, smooth_eps\n",
        "        self.K, self.nc = K, nc\n",
        "\n",
        "    def _effective_means(self, Mu, mu_anchor_z):\n",
        "        \"\"\"Component means used in the likelihood for the current mean_mode (broadcastable to Mu).\"\"\"\n",
        "        mode = self.mean_mode\n",
        "        if mode not in ('anchor', 'anchor+delta'):\n",
        "            return Mu\n",
        "        assert mu_anchor_z is not None\n",
        "        anchor = mu_anchor_z[:, None, None]\n",
        "        return anchor if mode == 'anchor' else anchor + Mu\n",
        "\n",
        "    def _mixture_params(self, X, train=True):\n",
        "        \"\"\"Return gate ``w`` [B,K], per-expert ``Pi``/``Mu``/``Sg`` [B,K,nc] and latent ``z``.\n",
        "\n",
        "        ``train=True`` uses the smoothed top-k gate, otherwise the hard top-k gate. Only\n",
        "        experts selected by at least one row are evaluated; other slots keep placeholder\n",
        "        values (uniform pi, zero mean, unit sigma) and carry zero gate weight.\n",
        "        \"\"\"\n",
        "        z = self.proj(X)\n",
        "        gate = self.win(z) * self.router(z)\n",
        "        gate = gate / (gate.sum(dim=-1, keepdim=True) + 1e-12)\n",
        "        if train:\n",
        "            gate = _topk_mask_smooth(gate, k=self.topk, eps=self.smooth_eps)\n",
        "        else:\n",
        "            gate = _topk_mask(gate, k=self.topk)\n",
        "\n",
        "        n, n_exp, n_comp = X.size(0), self.K, self.nc\n",
        "        dev = X.device\n",
        "        Pi = torch.full((n, n_exp, n_comp), 1.0 / n_comp, device=dev)\n",
        "        Mu = torch.zeros(n, n_exp, n_comp, device=dev)\n",
        "        Sg = torch.ones(n, n_exp, n_comp, device=dev)\n",
        "\n",
        "        chosen = torch.topk(gate, self.topk, dim=-1).indices\n",
        "        for j in torch.unique(chosen).tolist():\n",
        "            pi_j, mu_j, sg_j = self.exps[j](X)\n",
        "            sel = (chosen == j).any(dim=1).float().unsqueeze(-1)  # [B,1]: row selected expert j\n",
        "            rest = 1 - sel\n",
        "            Pi[:, j, :] = pi_j * sel + Pi[:, j, :] * rest\n",
        "            if mu_j is not None:\n",
        "                Mu[:, j, :] = mu_j * sel + Mu[:, j, :] * rest\n",
        "            Sg[:, j, :] = sg_j * sel + Sg[:, j, :] * rest\n",
        "\n",
        "        return gate, Pi, Mu, Sg, z\n",
        "\n",
        "    def nll(self, X, y_z, mu_anchor_z=None, epoch=1, warmup_epochs=150):\n",
        "        \"\"\"Mean negative log-likelihood of ``y_z`` under the gated mixture.\n",
        "\n",
        "        In training mode (``self.training``) the loss also includes: ``w_ent * sum(p*log p)``\n",
        "        on router probabilities (``w_ent_warm`` while ``epoch <= warmup_epochs``, else\n",
        "        ``w_ent_cool``), ``l2_win`` * mean squared window log-width, ``lb_coef`` * squared\n",
        "        deviation of mean router usage from uniform, and the 'anchor+delta' offset penalty.\n",
        "        In eval mode only the NLL (with hard top-k gating) is returned.\n",
        "        \"\"\"\n",
        "        training = self.training\n",
        "        w, Pi, Mu, Sg, z = self._mixture_params(X, train=training)\n",
        "        Mu_eff = self._effective_means(Mu, mu_anchor_z)\n",
        "        delta_pen = (Mu**2).mean() * self.delta_l2 if self.mean_mode == 'anchor+delta' else 0.0\n",
        "\n",
        "        yv = y_z[:, None, None]\n",
        "        log_gauss = -0.5 * ((yv - Mu_eff)/Sg)**2 - torch.log(Sg) - 0.5*LOG2PI\n",
        "        w3 = w[:, :, None]\n",
        "        # Experts with zero gate weight get a large negative log-weight instead of log(0).\n",
        "        log_w = torch.where(w3 > 0, torch.log(w3 + 1e-12), torch.full_like(w3, -1e9))\n",
        "        log_pi = torch.log(Pi + 1e-12)\n",
        "        loss = -torch.logsumexp(log_w + log_pi + log_gauss, dim=(1,2)).mean()\n",
        "        if not training:\n",
        "            return loss\n",
        "\n",
        "        w_ent = self.w_ent_warm if epoch <= warmup_epochs else self.w_ent_cool\n",
        "        probs = self.router(z)\n",
        "        neg_entropy = (probs * torch.log(probs + 1e-12)).sum(dim=1).mean()\n",
        "        l2w = (self.win.log_s**2).mean()\n",
        "        usage = probs.mean(dim=0)\n",
        "        lb_loss = ((usage - 1.0/probs.size(1))**2).sum()\n",
        "        return loss + w_ent * neg_entropy + self.l2_win * l2w + self.lb_coef * lb_loss + delta_pen\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def predict_mean_var(self, X, mu_anchor_z=None):\n",
        "        \"\"\"Predictive mean and variance per row using hard top-k gating (no gradients).\n",
        "\n",
        "        Variance = E[sigma^2 + mu^2] - mean^2 over components and experts, floored at 1e-9.\n",
        "        \"\"\"\n",
        "        w, Pi, Mu, Sg, _ = self._mixture_params(X, train=False)\n",
        "        Mu_eff = self._effective_means(Mu, mu_anchor_z)\n",
        "        expert_mean = (Pi * Mu_eff).sum(dim=2)\n",
        "        expert_m2 = (Pi * (Sg**2 + Mu_eff**2)).sum(dim=2)\n",
        "        mean_z = (w * expert_mean).sum(dim=1)\n",
        "        m2_z = (w * expert_m2).sum(dim=1)\n",
        "        var_z = torch.clamp(m2_z - mean_z**2, min=1e-9)\n",
        "        return mean_z, var_z\n",
        "\n",
        "def train_one_split_with_anchor(\n",
        "    X_all, y_all, X_te, y_te,\n",
        "    standardize_x=True,\n",
        "    D=2, K=8, HID=128, NC=3,\n",
        "    LR=1e-3, EPOCHS=400,\n",
        "    MEAN_MODE='anchor+delta',\n",
        "    DELTA_L2=1e-3,\n",
        "    SIGMA_MIN=5e-2, SIGMA_MAX=1.0,\n",
        "    TOPK=2, SMOOTH_EPS=0.05,\n",
        "    WARMUP_EPOCHS=150\n",
        "):\n",
        "    \"\"\"Fit a GBDT anchor and a BLRMoE on one train/test split; return test metrics.\n",
        "\n",
        "    Protocol:\n",
        "      1. Hold out 12.5% of ``(X_all, y_all)`` for calibration; split the remaining\n",
        "         train+val part 80/20 into train / validation.\n",
        "      2. GBDT: choose the number of boosting rounds by validation RMSE, then refit that\n",
        "         many rounds on train (``gbdt_sub``) and on train+val (``gbdt_final``).\n",
        "      3. Stage 1: train BLRMoE full-batch on train (features + standardized GBDT anchor\n",
        "         as an extra column) for ``EPOCHS`` steps, keeping the lowest-validation-NLL state.\n",
        "      4. Stage 2: warm-start from that state and train ``best_ep`` more steps on train+val,\n",
        "         with train+val standardization and ``gbdt_final`` anchors.\n",
        "      5. Fit ``y ~ a * mu + b`` on the calibration set and apply it to the test mean.\n",
        "\n",
        "    ``WARMUP_EPOCHS`` is forwarded to ``BLRMoE.nll`` as ``warmup_epochs``.\n",
        "\n",
        "    Returns:\n",
        "        ``(rmse, test_nll, best_it, best_ep)``: test RMSE of the calibrated mean (original\n",
        "        units); mean test NLL in standardized-target units, uncalibrated (add the log of the\n",
        "        train+val target std to express it in original units); GBDT rounds; stage-1 best epoch.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: no stage-1 epoch produced a validation NLL below 1e9 (e.g. all NaN\n",
        "            after divergence, or ``EPOCHS < 1``), so there is no checkpoint to restore.\n",
        "    \"\"\"\n",
        "    # Calibration hold-out, then train / validation split of the remainder.\n",
        "    X_tv, X_cal, y_tv, y_cal = train_test_split(X_all, y_all, test_size=0.125, random_state=1)\n",
        "    X_tr, X_va, y_tr, y_va   = train_test_split(X_tv,  y_tv,  test_size=0.2,   random_state=1)\n",
        "\n",
        "    # GBDT anchor: pick the round count by validation RMSE over the staged predictions.\n",
        "    gbdt_full_tr = GradientBoostingRegressor(n_estimators=2000, learning_rate=0.05,\n",
        "                                             max_depth=3, random_state=1, subsample=1.0)\n",
        "    gbdt_full_tr.fit(X_tr, y_tr)\n",
        "    val_rmse = [rmse_score(y_va, p) for p in gbdt_full_tr.staged_predict(X_va)]\n",
        "    best_it  = int(np.argmin(val_rmse)) + 1\n",
        "\n",
        "    gbdt_sub   = GradientBoostingRegressor(n_estimators=best_it, learning_rate=0.05,\n",
        "                                           max_depth=3, random_state=1, subsample=1.0).fit(X_tr, y_tr)\n",
        "    gbdt_final = GradientBoostingRegressor(n_estimators=best_it, learning_rate=0.05,\n",
        "                                           max_depth=3, random_state=1, subsample=1.0).fit(X_tv, y_tv)\n",
        "\n",
        "    # Target standardization: stage 1 uses train statistics, stage 2 train+val statistics.\n",
        "    my_tr, sy_tr = zscore_fit(y_tr)\n",
        "    my_tv, sy_tv = zscore_fit(y_tv)\n",
        "\n",
        "    # NOTE(review): the train anchor is gbdt_sub's in-sample prediction on X_tr, so it is\n",
        "    # likely tighter on train than on validation/test -- confirm this is intended.\n",
        "    mu_tr_anchor = (gbdt_sub.predict(X_tr) - my_tr) / sy_tr\n",
        "    mu_va_anchor = (gbdt_sub.predict(X_va) - my_tr) / sy_tr\n",
        "\n",
        "    # The standardized anchor is also appended as an extra input feature.\n",
        "    X_tr_aug = np.column_stack([X_tr, mu_tr_anchor])\n",
        "    X_va_aug = np.column_stack([X_va, mu_va_anchor])\n",
        "\n",
        "    if standardize_x:\n",
        "        mx_tr = X_tr_aug.mean(0, keepdims=True); sx_tr = X_tr_aug.std(0, keepdims=True) + 1e-8\n",
        "    else:\n",
        "        mx_tr = np.zeros((1, X_tr_aug.shape[1])); sx_tr = np.ones((1, X_tr_aug.shape[1]))\n",
        "\n",
        "    def to_tensor(x, yz=None, mx=None, sx=None):\n",
        "        # Standardize x with (mx, sx) -> float32 tensor on DEVICE; optionally also yz.\n",
        "        Xt = torch.tensor(((x - mx)/sx).astype(np.float32), device=DEVICE)\n",
        "        if yz is None: return Xt\n",
        "        Yt = torch.tensor(yz.astype(np.float32), device=DEVICE)\n",
        "        return Xt, Yt\n",
        "\n",
        "    y_tr_z = (y_tr - my_tr) / sy_tr\n",
        "    y_va_z = (y_va - my_tr) / sy_tr\n",
        "\n",
        "    model = BLRMoE(\n",
        "        d=X_tr_aug.shape[1], D=D, K=K, hid=HID, nc=NC,\n",
        "        mean_mode=MEAN_MODE, delta_l2=DELTA_L2,\n",
        "        sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX,\n",
        "        topk=TOPK, smooth_eps=SMOOTH_EPS\n",
        "    ).to(DEVICE)\n",
        "    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=3e-4)\n",
        "\n",
        "    Xtr_t, ytr_t = to_tensor(X_tr_aug, y_tr_z, mx_tr, sx_tr)\n",
        "    Xva_t, yva_t = to_tensor(X_va_aug, y_va_z, mx_tr, sx_tr)\n",
        "    mu_tr_t = torch.tensor(mu_tr_anchor.astype(np.float32), device=DEVICE)\n",
        "    mu_va_t = torch.tensor(mu_va_anchor.astype(np.float32), device=DEVICE)\n",
        "\n",
        "    # Stage 1: full-batch training, checkpointing the best validation NLL.\n",
        "    best_ep, best_vnll, best_state = 0, +1e9, None\n",
        "    for ep in range(1, EPOCHS+1):\n",
        "        model.train()\n",
        "        opt.zero_grad()\n",
        "        loss = model.nll(Xtr_t, ytr_t, mu_anchor_z=mu_tr_t, epoch=ep, warmup_epochs=WARMUP_EPOCHS)\n",
        "        loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), 2.0); opt.step()\n",
        "        model.router.tau = max(model.router.tau * 0.995, 1.0)  # anneal router temperature, floor 1.0\n",
        "\n",
        "        model.eval()\n",
        "        with torch.no_grad():\n",
        "            vnll = float(model.nll(Xva_t, yva_t, mu_anchor_z=mu_va_t, epoch=ep, warmup_epochs=WARMUP_EPOCHS).cpu().item())\n",
        "        if vnll < best_vnll:  # NaN never compares smaller, so NaN epochs are never checkpointed\n",
        "            best_vnll = vnll; best_ep = ep\n",
        "            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
        "\n",
        "    if best_state is None:\n",
        "        # Without this guard, load_state_dict(None) below fails with an unrelated error.\n",
        "        raise RuntimeError(\n",
        "            f\"train_one_split_with_anchor: no usable validation NLL in {EPOCHS} epoch(s) \"\n",
        "            \"(NaN/inf or EPOCHS < 1); training likely diverged.\"\n",
        "        )\n",
        "\n",
        "    # Stage 2 inputs: gbdt_final anchors and train+val statistics.\n",
        "    mu_tv_anchor = (gbdt_final.predict(X_tv) - my_tv) / sy_tv\n",
        "    mu_te_anchor = (gbdt_final.predict(X_te) - my_tv) / sy_tv\n",
        "    mu_cal_anchor= (gbdt_final.predict(X_cal)- my_tv) / sy_tv\n",
        "\n",
        "    X_tv_aug  = np.column_stack([X_tv,  mu_tv_anchor])\n",
        "    X_te_aug  = np.column_stack([X_te,  mu_te_anchor])\n",
        "    X_cal_aug = np.column_stack([X_cal, mu_cal_anchor])\n",
        "\n",
        "    if standardize_x:\n",
        "        mx_tv = X_tv_aug.mean(0, keepdims=True); sx_tv = X_tv_aug.std(0, keepdims=True) + 1e-8\n",
        "    else:\n",
        "        mx_tv = np.zeros((1, X_tv_aug.shape[1])); sx_tv = np.ones((1, X_tv_aug.shape[1]))\n",
        "\n",
        "    y_tv_z = (y_tv - my_tv) / sy_tv\n",
        "    y_te_z = (y_te - my_tv) / sy_tv\n",
        "\n",
        "    # Stage 2: warm-start from the best stage-1 state with a fresh optimizer.\n",
        "    # NOTE(review): router.tau is a plain attribute, not in best_state, so it keeps its\n",
        "    # end-of-stage-1 value rather than its value at best_ep -- confirm this is intended.\n",
        "    model.load_state_dict(best_state)\n",
        "    model.train()\n",
        "    opt = optim.AdamW(model.parameters(), lr=LR, weight_decay=3e-4)\n",
        "\n",
        "    Xtv_t, ytv_t = to_tensor(X_tv_aug, y_tv_z, mx_tv, sx_tv)\n",
        "    Xte_t        = to_tensor(X_te_aug, None,   mx_tv, sx_tv)\n",
        "    Xcal_t       = to_tensor(X_cal_aug, None,  mx_tv, sx_tv)\n",
        "    mu_tv_t  = torch.tensor(mu_tv_anchor.astype(np.float32),  device=DEVICE)\n",
        "    mu_te_t  = torch.tensor(mu_te_anchor.astype(np.float32),  device=DEVICE)\n",
        "    mu_cal_t = torch.tensor(mu_cal_anchor.astype(np.float32), device=DEVICE)\n",
        "\n",
        "    for ep in range(1, best_ep+1):\n",
        "        opt.zero_grad()\n",
        "        loss = model.nll(Xtv_t, ytv_t, mu_anchor_z=mu_tv_t, epoch=ep, warmup_epochs=WARMUP_EPOCHS)\n",
        "        loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), 2.0); opt.step()\n",
        "        model.router.tau = max(model.router.tau * 0.995, 1.0)\n",
        "\n",
        "    # Evaluation: test NLL in standardized units, predictive means mapped back to y units.\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        yte_t = torch.tensor(y_te_z.astype(np.float32), device=DEVICE)\n",
        "        test_nll = float(model.nll(Xte_t, yte_t, mu_anchor_z=mu_te_t).cpu().item())\n",
        "\n",
        "        mu_z_cal, _ = model.predict_mean_var(Xcal_t, mu_anchor_z=mu_cal_t)\n",
        "        mu_cal_orig = mu_z_cal.cpu().numpy().astype(np.float64) * sy_tv + my_tv\n",
        "\n",
        "        mu_z_te, _  = model.predict_mean_var(Xte_t,  mu_anchor_z=mu_te_t)\n",
        "        mu_te_orig  = mu_z_te.cpu().numpy().astype(np.float64) * sy_tv + my_tv\n",
        "\n",
        "    # Affine calibration of the predictive mean: least-squares fit y_cal ~ a * mu + b.\n",
        "    A = np.vstack([mu_cal_orig, np.ones_like(mu_cal_orig)]).T\n",
        "    ab, *_ = np.linalg.lstsq(A, y_cal.astype(np.float64), rcond=None)\n",
        "    a, b = float(ab[0]), float(ab[1])\n",
        "\n",
        "    mu_te_cal  = a * mu_te_orig + b\n",
        "    rmse = rmse_score(y_te.astype(np.float64), mu_te_cal)\n",
        "\n",
        "    return rmse, test_nll, best_it, best_ep\n",
        "\n",
        "def main():\n",
        "    DATASET = \"housing\" \n",
        "    X, y = load_dataset(DATASET)\n",
        "    print(f\"== Dataset={DATASET} X.shape={X.shape}\")\n",
        "\n",
        "    SEED = 1\n",
        "    np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(SEED)\n",
        "        torch.backends.cudnn.deterministic = True\n",
        "        torch.backends.cudnn.benchmark = False\n",
        "\n",
        "    STANDARDIZE_X = True\n",
        "    D, K, HID, NC = 2, 8, 128, 3\n",
        "    LR, EPOCHS    = 1e-3, 400\n",
        "    MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "    SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "    TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "    n = X.shape[0]\n",
        "    splits = []\n",
        "    for _ in range(20):\n",
        "        perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "        splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "    rmses, nlls = [], []\n",
        "    for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "        rmse, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "            X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "            standardize_x=STANDARDIZE_X,\n",
        "            D=D, K=K, HID=HID, NC=NC,\n",
        "            LR=LR, EPOCHS=EPOCHS,\n",
        "            MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "            SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "            TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "        )\n",
        "        print(f\"[{itr:02d}/20] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "              f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "        rmses.append(rmse); nlls.append(nll)\n",
        "        gc.collect();\n",
        "        if torch.cuda.is_available():\n",
        "            torch.cuda.empty_cache()\n",
        "\n",
        "    rmses = np.array(rmses, dtype=np.float64)\n",
        "    nlls  = np.array(nlls,  dtype=np.float64)\n",
        "    se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "    print(\"\\n== Anchor-MoE (coupled anchor feature, no-leak, topk-fixed) ==\")\n",
        "    print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "    print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "bEn0MUxTp2HA",
      "metadata": {
        "id": "bEn0MUxTp2HA"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "markdown",
      "id": "cU92htjS8Ou1",
      "metadata": {
        "id": "cU92htjS8Ou1"
      },
      "source": [
        "# **Boston Dataset**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a7f55d63-a72b-46ff-9811-9f1ffae49d53",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "a7f55d63-a72b-46ff-9811-9f1ffae49d53",
        "outputId": "75cfb3ef-1bc0-4164-d491-87c64fa4c1a1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Housing] X.shape=(506, 13) y.shape=(506,) | y∈[5.000,50.000]\n",
            "[01/20] GBDT best_iter=289  MoE best_ep=10  TestRMSE(orig)=2.5337  TestNLL(z)=0.3435\n",
            "[02/20] GBDT best_iter=305  MoE best_ep=10  TestRMSE(orig)=2.5833  TestNLL(z)=0.2538\n",
            "[03/20] GBDT best_iter=545  MoE best_ep=11  TestRMSE(orig)=2.1390  TestNLL(z)=0.1778\n",
            "[04/20] GBDT best_iter=455  MoE best_ep=13  TestRMSE(orig)=3.8278  TestNLL(z)=0.2141\n",
            "[05/20] GBDT best_iter=253  MoE best_ep=14  TestRMSE(orig)=3.7482  TestNLL(z)=1.1760\n",
            "[06/20] GBDT best_iter=298  MoE best_ep=10  TestRMSE(orig)=2.5234  TestNLL(z)=0.2582\n",
            "[07/20] GBDT best_iter=421  MoE best_ep=10  TestRMSE(orig)=2.1363  TestNLL(z)=0.1458\n",
            "[08/20] GBDT best_iter=234  MoE best_ep=11  TestRMSE(orig)=2.6484  TestNLL(z)=0.3162\n",
            "[09/20] GBDT best_iter=552  MoE best_ep=9  TestRMSE(orig)=3.5282  TestNLL(z)=0.8218\n",
            "[10/20] GBDT best_iter=416  MoE best_ep=11  TestRMSE(orig)=4.8476  TestNLL(z)=1.1085\n",
            "[11/20] GBDT best_iter=388  MoE best_ep=17  TestRMSE(orig)=3.5469  TestNLL(z)=2.2304\n",
            "[12/20] GBDT best_iter=277  MoE best_ep=13  TestRMSE(orig)=2.9400  TestNLL(z)=0.5324\n",
            "[13/20] GBDT best_iter=497  MoE best_ep=11  TestRMSE(orig)=2.8267  TestNLL(z)=0.1934\n",
            "[14/20] GBDT best_iter=617  MoE best_ep=10  TestRMSE(orig)=3.3176  TestNLL(z)=0.8125\n",
            "[15/20] GBDT best_iter=501  MoE best_ep=11  TestRMSE(orig)=3.0635  TestNLL(z)=0.3826\n",
            "[16/20] GBDT best_iter=768  MoE best_ep=8  TestRMSE(orig)=2.8972  TestNLL(z)=0.5998\n",
            "[17/20] GBDT best_iter=174  MoE best_ep=10  TestRMSE(orig)=2.6565  TestNLL(z)=0.3500\n",
            "[18/20] GBDT best_iter=469  MoE best_ep=10  TestRMSE(orig)=2.8840  TestNLL(z)=0.4439\n",
            "[19/20] GBDT best_iter=295  MoE best_ep=10  TestRMSE(orig)=2.8686  TestNLL(z)=0.4361\n",
            "[20/20] GBDT best_iter=288  MoE best_ep=13  TestRMSE(orig)=2.7764  TestNLL(z)=1.1040\n",
            "\n",
            "== Anchor-MoE (coupled anchor feature) on Housing ==\n",
            "RMSE (orig) = 3.0147 ± 0.1434\n",
            "NLL  (z)    = 0.5950 ± 0.1129\n"
          ]
        }
      ],
      "source": [
        "\n",
        "X, y = load_dataset(\"housing\")\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "        X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== Anchor-MoE (coupled anchor feature) on Housing ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "-KyJRnFu8iML",
      "metadata": {
        "id": "-KyJRnFu8iML"
      },
      "source": [
        "# **Concrete**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "db16b565-6a85-4eb7-a7a3-90d7f8e0d392",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "db16b565-6a85-4eb7-a7a3-90d7f8e0d392",
        "outputId": "87122b46-0857-4122-a3ff-0fb61490bc09"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Concrete] X.shape=(1030, 8) y.shape=(1030,) | y∈[2.332,82.599]\n",
            "[01/20] GBDT best_iter=1981  MoE best_ep=12  TestRMSE(orig)=4.9270  TestNLL(z)=0.2130\n",
            "[02/20] GBDT best_iter=1823  MoE best_ep=17  TestRMSE(orig)=3.3315  TestNLL(z)=-0.0048\n",
            "[03/20] GBDT best_iter=1204  MoE best_ep=17  TestRMSE(orig)=2.9753  TestNLL(z)=-0.3177\n",
            "[04/20] GBDT best_iter=1989  MoE best_ep=15  TestRMSE(orig)=4.3188  TestNLL(z)=0.0002\n",
            "[05/20] GBDT best_iter=2000  MoE best_ep=14  TestRMSE(orig)=4.9339  TestNLL(z)=0.4477\n",
            "[06/20] GBDT best_iter=1999  MoE best_ep=14  TestRMSE(orig)=3.8787  TestNLL(z)=0.1072\n",
            "[07/20] GBDT best_iter=2000  MoE best_ep=15  TestRMSE(orig)=5.4025  TestNLL(z)=0.7132\n",
            "[08/20] GBDT best_iter=1949  MoE best_ep=16  TestRMSE(orig)=4.7943  TestNLL(z)=0.1700\n",
            "[09/20] GBDT best_iter=1523  MoE best_ep=13  TestRMSE(orig)=4.1883  TestNLL(z)=-0.0573\n",
            "[10/20] GBDT best_iter=779  MoE best_ep=18  TestRMSE(orig)=5.1184  TestNLL(z)=0.5029\n",
            "[11/20] GBDT best_iter=943  MoE best_ep=15  TestRMSE(orig)=4.0216  TestNLL(z)=-0.0626\n",
            "[12/20] GBDT best_iter=1643  MoE best_ep=15  TestRMSE(orig)=4.3564  TestNLL(z)=0.2538\n",
            "[13/20] GBDT best_iter=1883  MoE best_ep=11  TestRMSE(orig)=4.5319  TestNLL(z)=0.1432\n",
            "[14/20] GBDT best_iter=1238  MoE best_ep=18  TestRMSE(orig)=3.5098  TestNLL(z)=0.4963\n",
            "[15/20] GBDT best_iter=1209  MoE best_ep=18  TestRMSE(orig)=4.7404  TestNLL(z)=0.6216\n",
            "[16/20] GBDT best_iter=1969  MoE best_ep=15  TestRMSE(orig)=5.7473  TestNLL(z)=0.4292\n",
            "[17/20] GBDT best_iter=169  MoE best_ep=16  TestRMSE(orig)=4.5873  TestNLL(z)=0.3247\n",
            "[18/20] GBDT best_iter=1447  MoE best_ep=14  TestRMSE(orig)=4.3549  TestNLL(z)=0.1012\n",
            "[19/20] GBDT best_iter=702  MoE best_ep=13  TestRMSE(orig)=5.3311  TestNLL(z)=0.6796\n",
            "[20/20] GBDT best_iter=1997  MoE best_ep=16  TestRMSE(orig)=3.9467  TestNLL(z)=0.2506\n",
            "\n",
            "== Anchor-MoE (coupled anchor feature) on Housing ==\n",
            "RMSE (orig) = 4.4498 ± 0.1596\n",
            "NLL  (z)    = 0.2506 ± 0.0616\n"
          ]
        }
      ],
      "source": [
        "\n",
        "X, y = load_dataset(\"concrete\")\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "        X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== Anchor-MoE (coupled anchor feature) on concrete ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "lhh8jEdF8dRI",
      "metadata": {
        "id": "lhh8jEdF8dRI"
      },
      "source": [
        "# **Energy**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b7277b32-277d-4f7f-b33e-091124ee003b",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "b7277b32-277d-4f7f-b33e-091124ee003b",
        "outputId": "b84cdec9-af07-4abb-8ab0-7726b539f4c4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Energy] X.shape=(768, 8) y.shape=(768,) | y∈[6.010,43.100]\n",
            "[01/20] GBDT best_iter=1986  MoE best_ep=172  TestRMSE(orig)=0.3912  TestNLL(z)=-1.7609\n",
            "[02/20] GBDT best_iter=1826  MoE best_ep=326  TestRMSE(orig)=0.5551  TestNLL(z)=-1.4699\n",
            "[03/20] GBDT best_iter=1982  MoE best_ep=231  TestRMSE(orig)=0.3986  TestNLL(z)=-1.7945\n",
            "[04/20] GBDT best_iter=2000  MoE best_ep=391  TestRMSE(orig)=0.4544  TestNLL(z)=-1.6483\n",
            "[05/20] GBDT best_iter=1828  MoE best_ep=186  TestRMSE(orig)=0.4850  TestNLL(z)=-1.6297\n",
            "[06/20] GBDT best_iter=1962  MoE best_ep=257  TestRMSE(orig)=0.3891  TestNLL(z)=-1.7654\n",
            "[07/20] GBDT best_iter=1949  MoE best_ep=40  TestRMSE(orig)=0.7941  TestNLL(z)=-1.5950\n",
            "[08/20] GBDT best_iter=1996  MoE best_ep=130  TestRMSE(orig)=0.3606  TestNLL(z)=-1.7804\n",
            "[09/20] GBDT best_iter=1999  MoE best_ep=221  TestRMSE(orig)=0.4991  TestNLL(z)=-1.5604\n",
            "[10/20] GBDT best_iter=1706  MoE best_ep=376  TestRMSE(orig)=0.3565  TestNLL(z)=-1.7761\n",
            "[11/20] GBDT best_iter=1652  MoE best_ep=85  TestRMSE(orig)=0.3821  TestNLL(z)=-1.7350\n",
            "[12/20] GBDT best_iter=568  MoE best_ep=366  TestRMSE(orig)=0.3994  TestNLL(z)=-1.7969\n",
            "[13/20] GBDT best_iter=1985  MoE best_ep=164  TestRMSE(orig)=0.6501  TestNLL(z)=-1.3871\n",
            "[14/20] GBDT best_iter=1560  MoE best_ep=297  TestRMSE(orig)=0.4537  TestNLL(z)=-1.6515\n",
            "[15/20] GBDT best_iter=1984  MoE best_ep=88  TestRMSE(orig)=0.3906  TestNLL(z)=-1.8308\n",
            "[16/20] GBDT best_iter=679  MoE best_ep=285  TestRMSE(orig)=0.3743  TestNLL(z)=-1.7858\n",
            "[17/20] GBDT best_iter=1518  MoE best_ep=157  TestRMSE(orig)=0.4994  TestNLL(z)=-1.5742\n",
            "[18/20] GBDT best_iter=1908  MoE best_ep=286  TestRMSE(orig)=0.4548  TestNLL(z)=-1.6611\n",
            "[19/20] GBDT best_iter=1361  MoE best_ep=168  TestRMSE(orig)=0.6094  TestNLL(z)=-1.6579\n",
            "[20/20] GBDT best_iter=1993  MoE best_ep=71  TestRMSE(orig)=0.4251  TestNLL(z)=-1.7301\n",
            "\n",
            "== Anchor-MoE (coupled anchor feature) on Housing ==\n",
            "RMSE (orig) = 0.4661 ± 0.0251\n",
            "NLL  (z)    = -1.6796 ± 0.0265\n"
          ]
        }
      ],
      "source": [
        "\n",
        "X, y = load_dataset(\"energy\")\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for i in range(20):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[: round(n * 0.9)], perm[round(n * 0.9):]))\n",
        "\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "        X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "    )\n",
        "    print(f\"[{itr:02d}/20] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))\n",
        "print(\"\\n== Anchor-MoE (coupled anchor feature) on Energy ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "zjwAnBwaVgzb",
      "metadata": {
        "id": "zjwAnBwaVgzb"
      },
      "source": [
        "# **Kin8nm**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "avMJRv0AkhIm",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "avMJRv0AkhIm",
        "outputId": "adf39249-6367-4687-f65f-dedf22ce971c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Kin8nm] (fallback) X.shape=(8192, 8) y.shape=(8192,) | y∈[0.040,1.459]\n",
            "== Dataset=Kin8nm X.shape=(8192, 8) y.shape=(8192,) ==\n",
            "[01/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0744  TestNLL(z)=0.1546\n",
            "[02/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0735  TestNLL(z)=0.1298\n",
            "[03/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0724  TestNLL(z)=0.1181\n",
            "[04/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0722  TestNLL(z)=0.1047\n",
            "[05/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0715  TestNLL(z)=0.1026\n",
            "[06/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0775  TestNLL(z)=0.2042\n",
            "[07/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0737  TestNLL(z)=0.1404\n",
            "[08/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0743  TestNLL(z)=0.1366\n",
            "[09/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0681  TestNLL(z)=0.0575\n",
            "[10/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0776  TestNLL(z)=0.1769\n",
            "[11/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0692  TestNLL(z)=0.0741\n",
            "[12/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0719  TestNLL(z)=0.1069\n",
            "[13/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0772  TestNLL(z)=0.2103\n",
            "[14/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0708  TestNLL(z)=0.0876\n",
            "[15/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0735  TestNLL(z)=0.1370\n",
            "[16/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0699  TestNLL(z)=0.0775\n",
            "[17/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0738  TestNLL(z)=0.1475\n",
            "[18/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0726  TestNLL(z)=0.1192\n",
            "[19/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0701  TestNLL(z)=0.0904\n",
            "[20/20] K=20  RFF=4096  lambda=1.0e-03 | TestRMSE=0.0722  TestNLL(z)=0.1014\n",
            "\n",
            "== RFF-KRR Mixture (GP-inspired MoE) on Kin8nm ==\n",
            "RMSE (orig) = 0.0728 ± 0.0006\n",
            "NLL  (z)    = 0.1239 ± 0.0092\n"
          ]
        }
      ],
      "source": [
        "import numpy as np, math, gc\n",
        "from sklearn.cluster import KMeans\n",
        "from sklearn.preprocessing import StandardScaler\n",
        "\n",
        "\n",
        "try:\n",
        "    X, y = load_dataset(\"kin8nm\")  \n",
        "except NameError:\n",
        "    import pandas as pd, os\n",
        "    assert os.path.exists(\"kin8nm.csv\"), \"kin8nm.csv\"\n",
        "    df = pd.read_csv(\"kin8nm.csv\")\n",
        "    X = df.iloc[:, :-1].to_numpy(np.float64)\n",
        "    y = df.iloc[:,  -1].to_numpy(np.float64)\n",
        "    print(f\"[Kin8nm] (fallback) X.shape={X.shape} y.shape={y.shape} | y∈[{y.min():.3f},{y.max():.3f}]\")\n",
        "print(f\"== Dataset=Kin8nm X.shape={X.shape} y.shape={y.shape} ==\")\n",
        "\n",
        "SUBSAMPLE = None   \n",
        "if SUBSAMPLE is not None and X.shape[0] > SUBSAMPLE:\n",
        "    idx = np.random.RandomState(1).choice(X.shape[0], SUBSAMPLE, replace=False)\n",
        "    X, y = X[idx], y[idx]\n",
        "    print(f\"[Kin8nm] subsampled: X.shape={X.shape}, y.shape={y.shape}\")\n",
        "\n",
        "SEED = 1\n",
        "rng = np.random.RandomState(SEED)\n",
        "n = X.shape[0]\n",
        "splits = []\n",
        "for _ in range(20):\n",
        "    perm = rng.permutation(n)\n",
        "    splits.append((perm[: round(n*0.9)], perm[round(n*0.9):]))\n",
        "\n",
        "def rmse(a, b): return float(np.sqrt(np.mean((a - b)**2)))\n",
        "def zfit(y_tr):\n",
        "    my = float(y_tr.mean()); sy = float(y_tr.std() + 1e-8)\n",
        "    return my, sy\n",
        "\n",
        "class RFF:\n",
        "    def __init__(self, d_in, n_feat=256, lengthscale=1.0, seed=0):\n",
        "        rs = np.random.RandomState(seed)\n",
        "        self.n_feat = int(n_feat)\n",
        "        self.lengthscale = float(lengthscale)\n",
        "        self.W = rs.normal(size=(d_in, self.n_feat)) / max(self.lengthscale, 1e-6)\n",
        "        self.b = rs.uniform(0.0, 2*np.pi, size=(self.n_feat,))\n",
        "        self.scale = np.sqrt(2.0 / self.n_feat)\n",
        "    def features(self, X):  \n",
        "        Z = X @ self.W + self.b\n",
        "        return self.scale * np.cos(Z) \n",
        "\n",
        "def weighted_ridge(Phi, y, w=None, lam=1e-3):\n",
        "    \n",
        "    if w is None:\n",
        "        A = Phi.T @ Phi + lam * np.eye(Phi.shape[1])\n",
        "        b = Phi.T @ y\n",
        "    else:\n",
        "        W = w[:, None]\n",
        "        A = Phi.T @ (W * Phi) + lam * np.eye(Phi.shape[1])\n",
        "        b = Phi.T @ (w * y)\n",
        "    w_lin = np.linalg.solve(A, b)\n",
        "    return w_lin, A \n",
        "\n",
        "def mixture_loglik_gaussians(yz, mus, sig2, mix_w):\n",
        "    \n",
        "    LOG2PI = math.log(2*math.pi)\n",
        "    var = sig2[None, :]\n",
        "    ll = -0.5*((yz[:, None] - mus)**2)/var - 0.5*np.log(var) - 0.5*LOG2PI\n",
        "    a = np.log(mix_w + 1e-12) + ll\n",
        "    a_max = a.max(axis=1, keepdims=True)\n",
        "    logp = a_max[:,0] + np.log(np.exp(a - a_max).sum(axis=1) + 1e-12)\n",
        "    return logp\n",
        "\n",
        "def run_one_split(X_tr_full, y_tr_full, X_te, y_te,\n",
        "                  K=3, n_feat=256, lam=1e-3, topk=None,\n",
        "                  use_anchor=False,  \n",
        "                  random_state=0):\n",
        "    scaler = StandardScaler().fit(X_tr_full)\n",
        "    Xtr = scaler.transform(X_tr_full)\n",
        "    Xte = scaler.transform(X_te)\n",
        "\n",
        "    my, sy = zfit(y_tr_full)\n",
        "    ytr_z = (y_tr_full - my) / sy\n",
        "    yte_z = (y_te      - my) / sy\n",
        "\n",
        "    if use_anchor:\n",
        "        from sklearn.ensemble import GradientBoostingRegressor\n",
        "        gbdt = GradientBoostingRegressor(n_estimators=1500, learning_rate=0.05,\n",
        "                                         max_depth=3, random_state=random_state)\n",
        "        gbdt.fit(X_tr_full, y_tr_full)\n",
        "        mu_tr_anchor_z = (gbdt.predict(X_tr_full) - my) / sy\n",
        "        mu_te_anchor_z = (gbdt.predict(X_te)      - my) / sy\n",
        "        target_tr = ytr_z - mu_tr_anchor_z\n",
        "    else:\n",
        "        mu_tr_anchor_z = np.zeros_like(ytr_z)\n",
        "        mu_te_anchor_z = np.zeros_like(yte_z)\n",
        "        target_tr = ytr_z\n",
        "\n",
        "    km = KMeans(n_clusters=K, n_init=10, random_state=random_state).fit(Xtr)\n",
        "    C = km.cluster_centers_              \n",
        "    labels = km.labels_\n",
        "    bw2 = []\n",
        "    for k in range(K):\n",
        "        pts = Xtr[labels == k]\n",
        "        if pts.shape[0] < 2:\n",
        "            bw2.append(1.0)\n",
        "        else:\n",
        "            bw2.append(np.mean(np.sum((pts - C[k])**2, axis=1)) + 1e-6)\n",
        "    bw2 = np.asarray(bw2)  \n",
        "\n",
        "    def soft_assign(Xin):\n",
        "        D2 = np.sum((Xin[:, None, :] - C[None, :, :])**2, axis=2)  \n",
        "        logits = - D2 / (2.0 * bw2[None, :])\n",
        "        P = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
        "        P = P / (P.sum(axis=1, keepdims=True) + 1e-12)\n",
        "        if topk is not None and topk < K:\n",
        "            idx = np.argsort(-P, axis=1)[:, :topk]\n",
        "            mask = np.zeros_like(P); mask[np.arange(P.shape[0])[:,None], idx] = 1\n",
        "            P = P * mask\n",
        "            P = P / (P.sum(axis=1, keepdims=True) + 1e-12)\n",
        "        return P  \n",
        "\n",
        "    W_tr = soft_assign(Xtr)   \n",
        "\n",
        "    N, d = Xtr.shape\n",
        "    experts = []\n",
        "    mus_tr = np.zeros((N, K))\n",
        "    for k in range(K):\n",
        "        rff = RFF(d_in=d, n_feat=n_feat, lengthscale=np.sqrt(bw2[k]), seed=random_state+100+k)\n",
        "        Phi = rff.features(Xtr)                        \n",
        "        w_k, A_k = weighted_ridge(Phi, target_tr, w=W_tr[:, k], lam=lam)\n",
        "        mu_k_tr  = Phi @ w_k                           \n",
        "        resid    = target_tr - mu_k_tr\n",
        "        wsum     = W_tr[:, k].sum() + 1e-12\n",
        "        sig2_k   = max(float((W_tr[:, k] * (resid**2)).sum() / wsum), 1e-8)\n",
        "        experts.append((rff, w_k, A_k, sig2_k))\n",
        "        mus_tr[:, k] = mu_k_tr\n",
        "\n",
        "    W_te = soft_assign(Xte)                     \n",
        "    Nt = Xte.shape[0]\n",
        "    mus_te = np.zeros((Nt, K))\n",
        "    sig2   = np.zeros((K,), dtype=np.float64)\n",
        "    for k, (rff, w_k, A_k, sig2_k) in enumerate(experts):\n",
        "        Phi_te = rff.features(Xte)\n",
        "        mus_te[:, k] = Phi_te @ w_k\n",
        "        sig2[k] = sig2_k\n",
        "        \n",
        "    mus_te_full = mus_te + mu_te_anchor_z[:, None]\n",
        "\n",
        "    logp = mixture_loglik_gaussians(yte_z, mus_te_full, sig2, W_te)\n",
        "    nll = -float(np.mean(logp))\n",
        "\n",
        "    mu_pred_z = np.sum(W_te * mus_te_full, axis=1)      \n",
        "    mu_pred   = mu_pred_z * sy + my\n",
        "    rmse_val  = rmse(mu_pred, y_te)\n",
        "\n",
        "    return rmse_val, nll\n",
        "\n",
        "K = 20             \n",
        "N_FEAT = 4096      \n",
        "LAMBDA = 1e-3    \n",
        "TOPK = K        \n",
        "USE_ANCHOR = None  \n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    X_tr, y_tr = X[tr_idx], y[tr_idx]\n",
        "    X_te, y_te = X[te_idx], y[te_idx]\n",
        "\n",
        "    r, n = run_one_split(\n",
        "        X_tr, y_tr, X_te, y_te,\n",
        "        K=K, n_feat=N_FEAT, lam=LAMBDA, topk=TOPK,\n",
        "        use_anchor=USE_ANCHOR, random_state=SEED + i\n",
        "    )\n",
        "    print(f\"[{i:02d}/20] K={K}  RFF={N_FEAT}  lambda={LAMBDA:.1e} | TestRMSE={r:.4f}  TestNLL(z)={n:.4f}\")\n",
        "    rmses.append(r); nlls.append(n); gc.collect()\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))\n",
        "print(\"\\n== RFF-KRR Mixture (GP-inspired MoE) on Kin8nm ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "s6vinNBonX3W",
      "metadata": {
        "id": "s6vinNBonX3W"
      },
      "source": [
        "# **Naval**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "YuiXDynzhdXl",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "YuiXDynzhdXl",
        "outputId": "92ac271e-626c-4511-dfa6-4b058138d409"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Naval(local)] X.shape=(11934, 16) y.shape=(11934,) | y∈[0.950000,1.000000]\n",
            "[01/20] GBDT best_iter=1999  MoE best_ep=215  TestRMSE(orig)=0.0011  TestNLL(z)=-1.1336\n",
            "[02/20] GBDT best_iter=2000  MoE best_ep=322  TestRMSE(orig)=0.0009  TestNLL(z)=-1.3619\n",
            "[03/20] GBDT best_iter=2000  MoE best_ep=280  TestRMSE(orig)=0.0009  TestNLL(z)=-1.4272\n",
            "[04/20] GBDT best_iter=2000  MoE best_ep=110  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2368\n",
            "[05/20] GBDT best_iter=2000  MoE best_ep=130  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2014\n",
            "[06/20] GBDT best_iter=2000  MoE best_ep=372  TestRMSE(orig)=0.0010  TestNLL(z)=-1.3299\n",
            "[07/20] GBDT best_iter=2000  MoE best_ep=399  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2364\n",
            "[08/20] GBDT best_iter=1998  MoE best_ep=235  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2702\n",
            "[09/20] GBDT best_iter=1999  MoE best_ep=145  TestRMSE(orig)=0.0011  TestNLL(z)=-1.2576\n",
            "[10/20] GBDT best_iter=2000  MoE best_ep=369  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2053\n",
            "[11/20] GBDT best_iter=2000  MoE best_ep=159  TestRMSE(orig)=0.0010  TestNLL(z)=-1.3139\n",
            "[12/20] GBDT best_iter=2000  MoE best_ep=378  TestRMSE(orig)=0.0010  TestNLL(z)=-1.1753\n",
            "[13/20] GBDT best_iter=2000  MoE best_ep=265  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2196\n",
            "[14/20] GBDT best_iter=2000  MoE best_ep=93  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2957\n",
            "[15/20] GBDT best_iter=1999  MoE best_ep=400  TestRMSE(orig)=0.0009  TestNLL(z)=-1.2998\n",
            "[16/20] GBDT best_iter=2000  MoE best_ep=400  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2994\n",
            "[17/20] GBDT best_iter=2000  MoE best_ep=155  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2006\n",
            "[18/20] GBDT best_iter=2000  MoE best_ep=138  TestRMSE(orig)=0.0010  TestNLL(z)=-1.3295\n",
            "[19/20] GBDT best_iter=2000  MoE best_ep=196  TestRMSE(orig)=0.0009  TestNLL(z)=-1.3124\n",
            "[20/20] GBDT best_iter=1999  MoE best_ep=301  TestRMSE(orig)=0.0010  TestNLL(z)=-1.2607\n",
            "\n",
            "== Anchor-MoE (coupled anchor feature) on Housing ==\n",
            "RMSE (orig) = 0.0010 ± 0.0000\n",
            "NLL  (z)    = -1.2683 ± 0.0156\n"
          ]
        }
      ],
      "source": [
        "# Anchor-MoE (coupled anchor feature) on the Naval dataset.\n",
        "# Protocol: N_SPLITS random train/test splits with a TRAIN_FRAC training share; reports the\n",
        "# mean ± standard error over splits of the test RMSE and test NLL returned by\n",
        "# `train_one_split_with_anchor`. Uses `load_dataset`, `train_one_split_with_anchor`,\n",
        "# `np`, `torch` and `gc` from earlier cells.\n",
        "# NOTE(review): this cell is repeated per dataset with only the dataset changing;\n",
        "# consider one shared helper parameterized by (DATASET, DATASET_LABEL).\n",
        "DATASET, DATASET_LABEL = \"naval\", \"Naval\"\n",
        "X, y = load_dataset(DATASET)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "# One full random permutation per split, drawn from the seeded global numpy stream.\n",
        "N_SPLITS, TRAIN_FRAC = 20, 0.9\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for i in range(N_SPLITS):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "# Anchor-MoE hyperparameters passed to train_one_split_with_anchor.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "        X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{N_SPLITS}] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    # Free memory between splits.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))  # standard error of the mean over splits\n",
        "print(f\"\\n== Anchor-MoE (coupled anchor feature) on {DATASET_LABEL} ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "KsLL0HHTvNvA",
      "metadata": {
        "id": "KsLL0HHTvNvA"
      },
      "source": [
        "# **Power Plant**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "VmAx3-01x_yU",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VmAx3-01x_yU",
        "outputId": "9270ec3c-0814-4656-f9e3-0fb7af0b9fdc"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Power] X.shape=(9568, 4) y.shape=(9568,) | y∈[420.260,495.760]\n",
            "[01/20] GBDT best_iter=1999  MoE best_ep=20  TestRMSE(orig)=3.2548  TestNLL(z)=-0.1435\n",
            "[02/20] GBDT best_iter=1998  MoE best_ep=20  TestRMSE(orig)=3.3316  TestNLL(z)=0.0056\n",
            "[03/20] GBDT best_iter=1997  MoE best_ep=21  TestRMSE(orig)=2.9826  TestNLL(z)=-0.1756\n",
            "[04/20] GBDT best_iter=1994  MoE best_ep=42  TestRMSE(orig)=3.0454  TestNLL(z)=-0.2052\n",
            "[05/20] GBDT best_iter=1998  MoE best_ep=22  TestRMSE(orig)=3.6840  TestNLL(z)=-0.1179\n",
            "[06/20] GBDT best_iter=1992  MoE best_ep=26  TestRMSE(orig)=3.4787  TestNLL(z)=-0.0480\n",
            "[07/20] GBDT best_iter=1999  MoE best_ep=18  TestRMSE(orig)=3.1511  TestNLL(z)=-0.1883\n",
            "[08/20] GBDT best_iter=1999  MoE best_ep=27  TestRMSE(orig)=3.0624  TestNLL(z)=-0.0496\n",
            "[09/20] GBDT best_iter=1991  MoE best_ep=20  TestRMSE(orig)=3.4551  TestNLL(z)=-0.1624\n",
            "[10/20] GBDT best_iter=1996  MoE best_ep=21  TestRMSE(orig)=3.3850  TestNLL(z)=-0.1650\n",
            "[11/20] GBDT best_iter=1999  MoE best_ep=24  TestRMSE(orig)=3.0285  TestNLL(z)=-0.1376\n",
            "[12/20] GBDT best_iter=1990  MoE best_ep=19  TestRMSE(orig)=3.4126  TestNLL(z)=0.0737\n",
            "[13/20] GBDT best_iter=2000  MoE best_ep=51  TestRMSE(orig)=3.0362  TestNLL(z)=-0.1869\n",
            "[14/20] GBDT best_iter=1998  MoE best_ep=55  TestRMSE(orig)=2.8220  TestNLL(z)=-0.1912\n",
            "[15/20] GBDT best_iter=2000  MoE best_ep=50  TestRMSE(orig)=3.3359  TestNLL(z)=-0.2726\n",
            "[16/20] GBDT best_iter=1998  MoE best_ep=24  TestRMSE(orig)=2.8244  TestNLL(z)=-0.2831\n",
            "[17/20] GBDT best_iter=2000  MoE best_ep=51  TestRMSE(orig)=2.9607  TestNLL(z)=-0.2732\n",
            "[18/20] GBDT best_iter=2000  MoE best_ep=18  TestRMSE(orig)=3.4793  TestNLL(z)=-0.1314\n",
            "[19/20] GBDT best_iter=1987  MoE best_ep=51  TestRMSE(orig)=3.4628  TestNLL(z)=-0.1825\n",
            "[20/20] GBDT best_iter=1986  MoE best_ep=19  TestRMSE(orig)=3.0958  TestNLL(z)=-0.1940\n",
            "\n",
            "== Anchor-MoE (coupled anchor feature) on Housing ==\n",
            "RMSE (orig) = 3.2144 ± 0.0547\n",
            "NLL  (z)    = -0.1514 ± 0.0202\n"
          ]
        }
      ],
      "source": [
        "# Anchor-MoE (coupled anchor feature) on the Power Plant dataset.\n",
        "# Protocol: N_SPLITS random train/test splits with a TRAIN_FRAC training share; reports the\n",
        "# mean ± standard error over splits of the test RMSE and test NLL returned by\n",
        "# `train_one_split_with_anchor`. Uses `load_dataset`, `train_one_split_with_anchor`,\n",
        "# `np`, `torch` and `gc` from earlier cells.\n",
        "# NOTE(review): this cell is repeated per dataset with only the dataset changing;\n",
        "# consider one shared helper parameterized by (DATASET, DATASET_LABEL).\n",
        "DATASET, DATASET_LABEL = \"power\", \"Power Plant\"\n",
        "X, y = load_dataset(DATASET)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "# One full random permutation per split, drawn from the seeded global numpy stream.\n",
        "N_SPLITS, TRAIN_FRAC = 20, 0.9\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for i in range(N_SPLITS):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "# Anchor-MoE hyperparameters passed to train_one_split_with_anchor.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "        X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{N_SPLITS}] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    # Free memory between splits.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))  # standard error of the mean over splits\n",
        "print(f\"\\n== Anchor-MoE (coupled anchor feature) on {DATASET_LABEL} ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "KeWUfFxgzIC3",
      "metadata": {
        "id": "KeWUfFxgzIC3"
      },
      "source": [
        "# **Protein**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "SGfcnVAmzPm1",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "SGfcnVAmzPm1",
        "outputId": "8bc04556-d234-4c9e-b680-69c9f2949787"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Protein(local)] X.shape=(45730, 9) y.shape=(45730,) | y∈[0.000,20.999]\n",
            "[Protein] subsampled to 10000 from 45730\n",
            "[01/20] GBDT best_iter=1820  MoE best_ep=68  TestRMSE(orig)=4.2760  TestNLL(z)=0.9240\n",
            "[02/20] GBDT best_iter=1986  MoE best_ep=6  TestRMSE(orig)=4.4256  TestNLL(z)=1.1784\n",
            "[03/20] GBDT best_iter=1685  MoE best_ep=4  TestRMSE(orig)=4.3811  TestNLL(z)=1.0640\n",
            "[04/20] GBDT best_iter=1721  MoE best_ep=57  TestRMSE(orig)=4.4167  TestNLL(z)=1.0832\n",
            "[05/20] GBDT best_iter=1998  MoE best_ep=5  TestRMSE(orig)=4.4133  TestNLL(z)=1.0971\n",
            "[06/20] GBDT best_iter=1495  MoE best_ep=5  TestRMSE(orig)=4.3702  TestNLL(z)=1.0654\n",
            "[07/20] GBDT best_iter=1178  MoE best_ep=114  TestRMSE(orig)=4.5190  TestNLL(z)=0.8933\n",
            "[08/20] GBDT best_iter=1995  MoE best_ep=2  TestRMSE(orig)=4.3996  TestNLL(z)=1.1519\n",
            "[09/20] GBDT best_iter=1704  MoE best_ep=7  TestRMSE(orig)=4.3023  TestNLL(z)=1.1320\n",
            "[10/20] GBDT best_iter=1965  MoE best_ep=5  TestRMSE(orig)=4.4172  TestNLL(z)=1.0711\n",
            "[11/20] GBDT best_iter=1667  MoE best_ep=65  TestRMSE(orig)=4.4366  TestNLL(z)=0.8475\n",
            "[12/20] GBDT best_iter=1996  MoE best_ep=6  TestRMSE(orig)=4.3781  TestNLL(z)=1.0513\n",
            "[13/20] GBDT best_iter=1046  MoE best_ep=104  TestRMSE(orig)=4.3811  TestNLL(z)=1.2130\n",
            "[14/20] GBDT best_iter=1948  MoE best_ep=6  TestRMSE(orig)=4.2935  TestNLL(z)=1.0406\n",
            "[15/20] GBDT best_iter=1660  MoE best_ep=64  TestRMSE(orig)=4.5287  TestNLL(z)=1.0147\n",
            "[16/20] GBDT best_iter=1809  MoE best_ep=5  TestRMSE(orig)=4.4236  TestNLL(z)=1.1228\n",
            "[17/20] GBDT best_iter=1544  MoE best_ep=73  TestRMSE(orig)=4.6718  TestNLL(z)=1.4097\n",
            "[18/20] GBDT best_iter=1878  MoE best_ep=65  TestRMSE(orig)=4.4248  TestNLL(z)=1.0120\n",
            "[19/20] GBDT best_iter=1808  MoE best_ep=6  TestRMSE(orig)=4.5208  TestNLL(z)=1.1146\n",
            "[20/20] GBDT best_iter=1778  MoE best_ep=67  TestRMSE(orig)=4.2482  TestNLL(z)=0.7410\n",
            "\n",
            "== Anchor-MoE (coupled anchor feature) on Protein (subsampled to 10k) ==\n",
            "RMSE (orig) = 4.4114 ± 0.0219\n",
            "NLL  (z)    = 1.0614 ± 0.0316\n"
          ]
        }
      ],
      "source": [
        "import numpy as np, gc, torch\n",
        "\n",
        "# Anchor-MoE (coupled anchor feature) on the Protein dataset, randomly subsampled to\n",
        "# N_SUB rows (all rows are used when the dataset has at most N_SUB rows).\n",
        "# Protocol: N_SPLITS random train/test splits with a TRAIN_FRAC training share; reports the\n",
        "# mean ± standard error over splits of the test RMSE and test NLL returned by\n",
        "# `train_one_split_with_anchor`. Uses `load_dataset` and `train_one_split_with_anchor`\n",
        "# from earlier cells.\n",
        "X, y = load_dataset(\"protein\")\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "N_SUB = 10000\n",
        "n_full = X.shape[0]\n",
        "if n_full > N_SUB:\n",
        "    # Dedicated RandomState, so the subsample is fixed by SEED alone.\n",
        "    rng_sub = np.random.RandomState(SEED)\n",
        "    idx_sub = rng_sub.choice(n_full, N_SUB, replace=False)\n",
        "    X, y = X[idx_sub], y[idx_sub]\n",
        "    # Label for the summary banner, derived from N_SUB (e.g. 10000 -> '10k').\n",
        "    DATA_TAG = f\"subsampled to {N_SUB / 1000:g}k\"\n",
        "    print(f\"[Protein] subsampled to {N_SUB} from {n_full}\")\n",
        "else:\n",
        "    DATA_TAG = \"full data\"\n",
        "    print(f\"[Protein] dataset size {n_full} ≤ {N_SUB}, use full data\")\n",
        "\n",
        "# Reset the global numpy stream so the splits below are determined by SEED.\n",
        "np.random.seed(SEED)\n",
        "\n",
        "# One full random permutation per split.\n",
        "N_SPLITS, TRAIN_FRAC = 20, 0.9\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for i in range(N_SPLITS):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "# Anchor-MoE hyperparameters passed to train_one_split_with_anchor.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "        X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{N_SPLITS}] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    # Free memory between splits.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))  # standard error of the mean over splits\n",
        "print(f\"\\n== Anchor-MoE (coupled anchor feature) on Protein ({DATA_TAG}) ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "sJZlYKZ005sQ",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sJZlYKZ005sQ",
        "outputId": "6389ddb6-ecbd-4597-8ae3-a6cc50e559c5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Protein(local)] X.shape=(45730, 9) y.shape=(45730,) | y∈[0.000,20.999]\n",
            "== Dataset=Protein(local) X.shape=(45730, 9) y.shape=(45730,) ==\n",
            "[Protein] subsampled to 10000 from 45730\n",
            "[01/20] TestRMSE=4.2750  TestNLL(z)=1.2519\n",
            "[02/20] TestRMSE=4.5089  TestNLL(z)=1.2483\n",
            "[03/20] TestRMSE=4.4202  TestNLL(z)=1.3973\n",
            "[04/20] TestRMSE=4.3900  TestNLL(z)=1.1127\n",
            "[05/20] TestRMSE=4.3676  TestNLL(z)=1.6895\n",
            "[06/20] TestRMSE=4.4145  TestNLL(z)=1.3132\n",
            "[07/20] TestRMSE=4.4981  TestNLL(z)=1.3205\n",
            "[08/20] TestRMSE=4.4001  TestNLL(z)=1.1653\n",
            "[09/20] TestRMSE=4.3782  TestNLL(z)=1.0529\n",
            "[10/20] TestRMSE=4.4477  TestNLL(z)=1.1204\n",
            "[11/20] TestRMSE=4.4191  TestNLL(z)=1.0913\n",
            "[12/20] TestRMSE=4.4724  TestNLL(z)=1.3240\n",
            "[13/20] TestRMSE=4.3791  TestNLL(z)=1.1726\n",
            "[14/20] TestRMSE=4.4829  TestNLL(z)=0.9971\n",
            "[15/20] TestRMSE=4.5001  TestNLL(z)=1.3051\n",
            "[16/20] TestRMSE=4.4920  TestNLL(z)=1.1451\n",
            "[17/20] TestRMSE=4.6766  TestNLL(z)=1.2180\n",
            "[18/20] TestRMSE=4.5327  TestNLL(z)=1.3616\n",
            "[19/20] TestRMSE=4.5361  TestNLL(z)=1.4836\n",
            "[20/20] TestRMSE=4.2321  TestNLL(z)=1.1184\n",
            "\n",
            "== NGBoost (Normal) on Protein (subsampled to 10k) ==\n",
            "RMSE (orig) = 4.4412 ± 0.0218\n",
            "NLL  (z)    = 1.2444 ± 0.0364\n"
          ]
        }
      ],
      "source": [
        "import numpy as np, math, gc, os, sys, subprocess\n",
        "\n",
        "# NGBoost (Normal) baseline on the Protein dataset, using the same SEED, subsample size\n",
        "# and N_SPLITS x TRAIN_FRAC split protocol as the Anchor-MoE cell. NLL is reported on\n",
        "# the target z-scored with train-split statistics.\n",
        "\n",
        "# 1) ensure ngboost is importable, installing it into this kernel's interpreter if needed\n",
        "try:\n",
        "    from ngboost import NGBRegressor\n",
        "    from ngboost.distns import Normal\n",
        "    from ngboost.scores import MLE\n",
        "except ImportError:\n",
        "    print(\"[setup] Installing ngboost ...\")\n",
        "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ngboost\"])\n",
        "    from ngboost import NGBRegressor\n",
        "    from ngboost.distns import Normal\n",
        "    from ngboost.scores import MLE\n",
        "\n",
        "try:\n",
        "    X, y = load_dataset(\"protein\")\n",
        "except NameError:\n",
        "    # Fallback when the loader cell was not run: local CSV with the target in column 0.\n",
        "    import pandas as pd\n",
        "    assert os.path.exists(\"protein.csv\"), \"缺少本地 protein.csv（y， X）\"\n",
        "    df = pd.read_csv(\"protein.csv\")\n",
        "    y = df.iloc[:, 0].to_numpy(np.float64)\n",
        "    X = df.iloc[:, 1:].to_numpy(np.float64)\n",
        "print(f\"== Dataset=Protein(local) X.shape={X.shape} y.shape={y.shape} ==\")\n",
        "\n",
        "SEED = 1\n",
        "RS = np.random.RandomState(SEED)\n",
        "N_SUB = 10000\n",
        "n_full = X.shape[0]\n",
        "if n_full > N_SUB:\n",
        "    idx_sub = RS.choice(n_full, N_SUB, replace=False)\n",
        "    X, y = X[idx_sub], y[idx_sub]\n",
        "    # Label for the summary banner, derived from N_SUB (e.g. 10000 -> '10k').\n",
        "    DATA_TAG = f\"subsampled to {N_SUB / 1000:g}k\"\n",
        "    print(f\"[Protein] subsampled to {N_SUB} from {n_full}\")\n",
        "else:\n",
        "    DATA_TAG = \"full data\"\n",
        "    print(f\"[Protein] dataset size {n_full} ≤ {N_SUB}, use full data\")\n",
        "\n",
        "# Fresh RandomState so the splits depend only on SEED, not on whether we subsampled.\n",
        "RS = np.random.RandomState(SEED)\n",
        "N_SPLITS, TRAIN_FRAC = 20, 0.9\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for _ in range(N_SPLITS):\n",
        "    perm = RS.permutation(n)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "def rmse(a, b): return float(np.sqrt(np.mean((a - b)**2)))\n",
        "\n",
        "def zfit(y_tr):\n",
        "    \"\"\"Return (mean, std + 1e-8) of the training targets, used to z-score y.\"\"\"\n",
        "    my = float(y_tr.mean()); sy = float(y_tr.std() + 1e-8)\n",
        "    return my, sy\n",
        "\n",
        "LOG2PI = math.log(2*math.pi)\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    X_tr, y_tr = X[tr_idx], y[tr_idx]\n",
        "    X_te, y_te = X[te_idx], y[te_idx]\n",
        "\n",
        "    # z-score the test targets with train-split statistics (NLL is reported on this scale).\n",
        "    my, sy = zfit(y_tr)\n",
        "    y_te_z = (y_te - my) / sy\n",
        "\n",
        "    # Fixed 2000 boosting rounds (no early stopping); NGBoost is fit on the raw target.\n",
        "    ngb = NGBRegressor(\n",
        "        Dist=Normal,\n",
        "        Score=MLE,\n",
        "        n_estimators=2000,\n",
        "        learning_rate=0.05,\n",
        "        natural_gradient=True,\n",
        "        random_state=SEED + i,\n",
        "        verbose=False\n",
        "    )\n",
        "    ngb.fit(X_tr, y_tr)\n",
        "\n",
        "    dist = ngb.pred_dist(X_te)\n",
        "    mu = dist.loc\n",
        "    sig = np.maximum(dist.scale, 1e-12)\n",
        "\n",
        "    # Map the predictive Normal onto the z-scale: the mean is shifted/scaled, the std scaled by 1/sy.\n",
        "    mu_z = (mu - my) / sy\n",
        "    sig_z = sig / sy\n",
        "    sig2_z = np.maximum(sig_z**2, 1e-12)\n",
        "\n",
        "    # Mean Gaussian negative log-likelihood of the z-scored test targets.\n",
        "    nll = 0.5*np.log(sig2_z) + 0.5*((y_te_z - mu_z)**2)/sig2_z + 0.5*LOG2PI\n",
        "    nll = float(np.mean(nll))\n",
        "\n",
        "    rmse_val = rmse(mu, y_te)\n",
        "\n",
        "    print(f\"[{i:02d}/{N_SPLITS}] TestRMSE={rmse_val:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse_val); nlls.append(nll)\n",
        "    gc.collect()\n",
        "\n",
        "rmses = np.array(rmses, np.float64); nlls = np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))  # standard error of the mean over splits\n",
        "print(f\"\\n== NGBoost (Normal) on Protein ({DATA_TAG}) ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "m1qQtWhN6ruX",
      "metadata": {
        "id": "m1qQtWhN6ruX"
      },
      "source": [
        "# **Wine Quality (Red)**\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "VKmSHh5r9RDu",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VKmSHh5r9RDu",
        "outputId": "22be12b2-1fc5-43b7-fb69-8b8b19addc9d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[WineRed] X.shape=(1599, 11) y.shape=(1599,) | y∈[3.000,8.000]\n",
            "[01/20] GBDT best_iter=390  MoE best_ep=3  TestRMSE(orig)=0.5995  TestNLL(z)=1.1511\n",
            "[02/20] GBDT best_iter=229  MoE best_ep=3  TestRMSE(orig)=0.5963  TestNLL(z)=1.1316\n",
            "[03/20] GBDT best_iter=96  MoE best_ep=2  TestRMSE(orig)=0.7258  TestNLL(z)=1.3395\n",
            "[04/20] GBDT best_iter=95  MoE best_ep=2  TestRMSE(orig)=0.5908  TestNLL(z)=1.1001\n",
            "[05/20] GBDT best_iter=437  MoE best_ep=3  TestRMSE(orig)=0.5632  TestNLL(z)=1.1395\n",
            "[06/20] GBDT best_iter=337  MoE best_ep=3  TestRMSE(orig)=0.6185  TestNLL(z)=1.1745\n",
            "[07/20] GBDT best_iter=1407  MoE best_ep=5  TestRMSE(orig)=0.6411  TestNLL(z)=1.3395\n",
            "[08/20] GBDT best_iter=708  MoE best_ep=4  TestRMSE(orig)=0.6841  TestNLL(z)=1.3782\n",
            "[09/20] GBDT best_iter=617  MoE best_ep=5  TestRMSE(orig)=0.5763  TestNLL(z)=1.1883\n",
            "[10/20] GBDT best_iter=504  MoE best_ep=6  TestRMSE(orig)=0.6395  TestNLL(z)=1.3746\n",
            "[11/20] GBDT best_iter=305  MoE best_ep=4  TestRMSE(orig)=0.6652  TestNLL(z)=1.2604\n",
            "[12/20] GBDT best_iter=397  MoE best_ep=2  TestRMSE(orig)=0.6501  TestNLL(z)=1.2450\n",
            "[13/20] GBDT best_iter=142  MoE best_ep=2  TestRMSE(orig)=0.6421  TestNLL(z)=1.1934\n",
            "[14/20] GBDT best_iter=309  MoE best_ep=4  TestRMSE(orig)=0.6049  TestNLL(z)=1.1634\n",
            "[15/20] GBDT best_iter=571  MoE best_ep=2  TestRMSE(orig)=0.6077  TestNLL(z)=1.1074\n",
            "[16/20] GBDT best_iter=66  MoE best_ep=2  TestRMSE(orig)=0.5685  TestNLL(z)=1.0537\n",
            "[17/20] GBDT best_iter=654  MoE best_ep=5  TestRMSE(orig)=0.6127  TestNLL(z)=1.3863\n",
            "[18/20] GBDT best_iter=68  MoE best_ep=11  TestRMSE(orig)=0.6008  TestNLL(z)=1.0386\n",
            "[19/20] GBDT best_iter=61  MoE best_ep=10  TestRMSE(orig)=0.6186  TestNLL(z)=1.1290\n",
            "[20/20] GBDT best_iter=84  MoE best_ep=4  TestRMSE(orig)=0.6416  TestNLL(z)=1.1677\n",
            "\n",
            "== Anchor-MoE (coupled anchor feature) on Housing ==\n",
            "RMSE (orig) = 0.6224 ± 0.0089\n",
            "NLL  (z)    = 1.2031 ± 0.0244\n"
          ]
        }
      ],
      "source": [
        "# Anchor-MoE (coupled anchor feature) on the red Wine Quality dataset.\n",
        "# Protocol: N_SPLITS random train/test splits with a TRAIN_FRAC training share; reports the\n",
        "# mean ± standard error over splits of the test RMSE and test NLL returned by\n",
        "# `train_one_split_with_anchor`. Uses `load_dataset`, `train_one_split_with_anchor`,\n",
        "# `np`, `torch` and `gc` from earlier cells.\n",
        "# NOTE(review): this cell is repeated per dataset with only the dataset changing;\n",
        "# consider one shared helper parameterized by (DATASET, DATASET_LABEL).\n",
        "DATASET, DATASET_LABEL = \"wine\", \"Wine (red)\"\n",
        "X, y = load_dataset(DATASET)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "# One full random permutation per split, drawn from the seeded global numpy stream.\n",
        "N_SPLITS, TRAIN_FRAC = 20, 0.9\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for i in range(N_SPLITS):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "# Anchor-MoE hyperparameters passed to train_one_split_with_anchor.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    rmse, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "        X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{N_SPLITS}] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "          f\"TestRMSE(orig)={rmse:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse); nlls.append(nll)\n",
        "    # Free memory between splits.\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1) / np.sqrt(len(a))  # standard error of the mean over splits\n",
        "print(f\"\\n== Anchor-MoE (coupled anchor feature) on {DATASET_LABEL} ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "95M6quno-sIQ",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "95M6quno-sIQ",
        "outputId": "855297a7-6fac-42f0-cb4f-cade14989784"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[WineRed] X.shape=(1599, 11) y.shape=(1599,) | y∈[3.000,8.000]\n",
            "== Dataset=Protein(local) X.shape=(1599, 11) y.shape=(1599,) ==\n",
            "[Protein] dataset size 1599 ≤ 1599, use full data\n",
            "[01/20] TestRMSE=0.5799  TestNLL(z)=2.7180\n",
            "[02/20] TestRMSE=0.5681  TestNLL(z)=4.0791\n",
            "[03/20] TestRMSE=0.7069  TestNLL(z)=6.5129\n",
            "[04/20] TestRMSE=0.5870  TestNLL(z)=2.7675\n",
            "[05/20] TestRMSE=0.5591  TestNLL(z)=2.2286\n",
            "[06/20] TestRMSE=0.6033  TestNLL(z)=8.1670\n",
            "[07/20] TestRMSE=0.6130  TestNLL(z)=2.0925\n",
            "[08/20] TestRMSE=0.6593  TestNLL(z)=4.7736\n",
            "[09/20] TestRMSE=0.5688  TestNLL(z)=2.7527\n",
            "[10/20] TestRMSE=0.6168  TestNLL(z)=5.0409\n",
            "[11/20] TestRMSE=0.6293  TestNLL(z)=10.5768\n",
            "[12/20] TestRMSE=0.6307  TestNLL(z)=2.4453\n",
            "[13/20] TestRMSE=0.6636  TestNLL(z)=6.7118\n",
            "[14/20] TestRMSE=0.5923  TestNLL(z)=4.7667\n",
            "[15/20] TestRMSE=0.5628  TestNLL(z)=3.7086\n",
            "[16/20] TestRMSE=0.5304  TestNLL(z)=3.3801\n",
            "[17/20] TestRMSE=0.6056  TestNLL(z)=2.4574\n",
            "[18/20] TestRMSE=0.5718  TestNLL(z)=9.8242\n",
            "[19/20] TestRMSE=0.5763  TestNLL(z)=5.0681\n",
            "[20/20] TestRMSE=0.5971  TestNLL(z)=9.1310\n",
            "\n",
            "== NGBoost (Normal) on Protein (subsampled to 10k) ==\n",
            "RMSE (orig) = 0.6011 ± 0.0093\n",
            "NLL  (z)    = 4.9601 ± 0.5988\n"
          ]
        }
      ],
      "source": [
        "import numpy as np, math, gc, os, sys, subprocess\n",
        "\n",
        "# NGBoost (Normal) baseline on the red Wine Quality dataset (all rows), with N_SPLITS\n",
        "# random train/test splits (TRAIN_FRAC train). NLL is reported on the target z-scored\n",
        "# with train-split statistics.\n",
        "try:\n",
        "    from ngboost import NGBRegressor\n",
        "    from ngboost.distns import Normal\n",
        "    from ngboost.scores import MLE\n",
        "except ImportError:\n",
        "    print(\"[setup] Installing ngboost ...\")\n",
        "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ngboost\"])\n",
        "    from ngboost import NGBRegressor\n",
        "    from ngboost.distns import Normal\n",
        "    from ngboost.scores import MLE\n",
        "\n",
        "# No CSV fallback: require the notebook's loader so this cell cannot silently\n",
        "# evaluate a different dataset.\n",
        "if \"load_dataset\" not in globals():\n",
        "    raise RuntimeError(\"load_dataset() is not defined; run the dataset-loading cell first.\")\n",
        "X, y = load_dataset(\"wine\")\n",
        "print(f\"== Dataset=WineRed X.shape={X.shape} y.shape={y.shape} ==\")\n",
        "\n",
        "# No subsampling for this small dataset: every row is used.\n",
        "SEED = 1\n",
        "RS = np.random.RandomState(SEED)\n",
        "N_SPLITS, TRAIN_FRAC = 20, 0.9\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for _ in range(N_SPLITS):\n",
        "    perm = RS.permutation(n)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "def rmse(a, b): return float(np.sqrt(np.mean((a - b)**2)))\n",
        "\n",
        "def zfit(y_tr):\n",
        "    \"\"\"Return (mean, std + 1e-8) of the training targets, used to z-score y.\"\"\"\n",
        "    my = float(y_tr.mean()); sy = float(y_tr.std() + 1e-8)\n",
        "    return my, sy\n",
        "\n",
        "LOG2PI = math.log(2*math.pi)\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for i, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    X_tr, y_tr = X[tr_idx], y[tr_idx]\n",
        "    X_te, y_te = X[te_idx], y[te_idx]\n",
        "\n",
        "    # z-score the test targets with train-split statistics (NLL is reported on this scale).\n",
        "    my, sy = zfit(y_tr)\n",
        "    y_te_z = (y_te - my) / sy\n",
        "\n",
        "    # Fixed 2000 boosting rounds (no early stopping); NGBoost is fit on the raw target.\n",
        "    # NOTE(review): the recorded NLL here is far higher than Anchor-MoE's on this dataset;\n",
        "    # the predicted scales may be overconfident without validation-based early stopping.\n",
        "    ngb = NGBRegressor(\n",
        "        Dist=Normal,\n",
        "        Score=MLE,\n",
        "        n_estimators=2000,\n",
        "        learning_rate=0.05,\n",
        "        natural_gradient=True,\n",
        "        random_state=SEED + i,\n",
        "        verbose=False\n",
        "    )\n",
        "    ngb.fit(X_tr, y_tr)\n",
        "\n",
        "    dist = ngb.pred_dist(X_te)\n",
        "    mu = dist.loc\n",
        "    sig = np.maximum(dist.scale, 1e-12)\n",
        "\n",
        "    # Map the predictive Normal onto the z-scale: the mean is shifted/scaled, the std scaled by 1/sy.\n",
        "    mu_z = (mu - my) / sy\n",
        "    sig_z = sig / sy\n",
        "    sig2_z = np.maximum(sig_z**2, 1e-12)\n",
        "\n",
        "    # Mean Gaussian negative log-likelihood of the z-scored test targets.\n",
        "    nll = 0.5*np.log(sig2_z) + 0.5*((y_te_z - mu_z)**2)/sig2_z + 0.5*LOG2PI\n",
        "    nll = float(np.mean(nll))\n",
        "\n",
        "    rmse_val = rmse(mu, y_te)\n",
        "\n",
        "    print(f\"[{i:02d}/{N_SPLITS}] TestRMSE={rmse_val:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse_val); nlls.append(nll)\n",
        "    gc.collect()\n",
        "\n",
        "rmses = np.array(rmses, np.float64); nlls = np.array(nlls, np.float64)\n",
        "se = lambda a: a.std(ddof=1)/np.sqrt(len(a))  # standard error of the mean over splits\n",
        "print(\"\\n== NGBoost (Normal) on Wine (red) ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "O0sN5Pu-RjPe",
      "metadata": {
        "id": "O0sN5Pu-RjPe"
      },
      "source": []
    },
    {
      "cell_type": "markdown",
      "id": "ROrysG5T6ykz",
      "metadata": {
        "id": "ROrysG5T6ykz"
      },
      "source": [
        "\n",
        "# **Yacht**\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Xz-BusDhmuMQ",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Xz-BusDhmuMQ",
        "outputId": "0f605330-eec4-442d-d025-12f752d8483d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Yacht] X.shape=(308, 6) y.shape=(308,) | y∈[0.010,62.420]\n",
            "[01/20] GBDT best_iter=1997  MoE best_ep=57  TestRMSE(orig)=0.7562  TestNLL(z)=-1.8476\n",
            "[02/20] GBDT best_iter=1956  MoE best_ep=278  TestRMSE(orig)=0.6752  TestNLL(z)=-1.9226\n",
            "[03/20] GBDT best_iter=1908  MoE best_ep=116  TestRMSE(orig)=0.3375  TestNLL(z)=-1.9768\n",
            "[04/20] GBDT best_iter=2000  MoE best_ep=91  TestRMSE(orig)=0.5994  TestNLL(z)=-1.8201\n",
            "[05/20] GBDT best_iter=1930  MoE best_ep=161  TestRMSE(orig)=0.6423  TestNLL(z)=-1.8032\n",
            "[06/20] GBDT best_iter=1999  MoE best_ep=71  TestRMSE(orig)=0.2992  TestNLL(z)=-1.9629\n",
            "[07/20] GBDT best_iter=1999  MoE best_ep=313  TestRMSE(orig)=0.2742  TestNLL(z)=-2.0070\n",
            "[08/20] GBDT best_iter=1998  MoE best_ep=236  TestRMSE(orig)=0.4175  TestNLL(z)=-1.6997\n",
            "[09/20] GBDT best_iter=1462  MoE best_ep=395  TestRMSE(orig)=0.6515  TestNLL(z)=-1.9014\n",
            "[10/20] GBDT best_iter=1983  MoE best_ep=267  TestRMSE(orig)=0.3339  TestNLL(z)=-2.0049\n",
            "[11/20] GBDT best_iter=1940  MoE best_ep=376  TestRMSE(orig)=0.4289  TestNLL(z)=-1.8354\n",
            "[12/20] GBDT best_iter=1703  MoE best_ep=169  TestRMSE(orig)=0.8642  TestNLL(z)=-1.6447\n",
            "[13/20] GBDT best_iter=1998  MoE best_ep=203  TestRMSE(orig)=1.2122  TestNLL(z)=-1.6999\n",
            "[14/20] GBDT best_iter=303  MoE best_ep=50  TestRMSE(orig)=0.6242  TestNLL(z)=-1.7765\n",
            "[15/20] GBDT best_iter=237  MoE best_ep=242  TestRMSE(orig)=1.1062  TestNLL(z)=-1.2076\n",
            "[16/20] GBDT best_iter=1431  MoE best_ep=245  TestRMSE(orig)=0.4566  TestNLL(z)=-1.9347\n",
            "[17/20] GBDT best_iter=989  MoE best_ep=390  TestRMSE(orig)=0.5604  TestNLL(z)=-1.8974\n",
            "[18/20] GBDT best_iter=762  MoE best_ep=251  TestRMSE(orig)=0.7232  TestNLL(z)=-1.7651\n",
            "[19/20] GBDT best_iter=1792  MoE best_ep=107  TestRMSE(orig)=0.5363  TestNLL(z)=-1.8633\n",
            "[20/20] GBDT best_iter=1999  MoE best_ep=109  TestRMSE(orig)=0.9167  TestNLL(z)=-1.4953\n",
            "\n",
            "== Anchor-MoE (coupled anchor feature) on Housing ==\n",
            "RMSE (orig) = 0.6208 ± 0.0578\n",
            "NLL  (z)    = -1.8033 ± 0.0427\n"
          ]
        }
      ],
      "source": [
        "\n",
        "DATASET = \"yacht\"\n",
        "X, y = load_dataset(DATASET)\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED); torch.manual_seed(SEED)\n",
        "if torch.cuda.is_available():\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "# Random train/test splits drawn from the seeded global NumPy RNG.\n",
        "# np.random.choice(..., replace=False) is kept as-is so the drawn splits\n",
        "# are identical to those behind the recorded outputs.\n",
        "N_SPLITS, TRAIN_FRAC = 20, 0.9\n",
        "n = X.shape[0]\n",
        "n_train = round(n * TRAIN_FRAC)\n",
        "splits = []\n",
        "for _ in range(N_SPLITS):\n",
        "    perm = np.random.choice(np.arange(n), n, replace=False)\n",
        "    splits.append((perm[:n_train], perm[n_train:]))\n",
        "\n",
        "# Anchor-MoE hyperparameters, passed straight through to train_one_split_with_anchor.\n",
        "STANDARDIZE_X = True\n",
        "D, K, HID, NC = 2, 8, 128, 3\n",
        "LR, EPOCHS = 1e-3, 400\n",
        "MEAN_MODE, DELTA_L2 = 'anchor+delta', 3e-3\n",
        "SIGMA_MIN, SIGMA_MAX = 5e-2, 1.0\n",
        "TOPK, SMOOTH_EPS = 2, 0.05\n",
        "\n",
        "rmses, nlls = [], []\n",
        "for itr, (tr_idx, te_idx) in enumerate(splits, 1):\n",
        "    # Unpack into `rmse_val`, not `rmse`: rebinding `rmse` would replace the\n",
        "    # rmse() helper defined in an earlier cell with a float, breaking any later\n",
        "    # cell that calls rmse() without redefining it.\n",
        "    rmse_val, nll, m_it, moe_ep = train_one_split_with_anchor(\n",
        "        X[tr_idx], y[tr_idx], X[te_idx], y[te_idx],\n",
        "        standardize_x=STANDARDIZE_X,\n",
        "        D=D, K=K, HID=HID, NC=NC,\n",
        "        LR=LR, EPOCHS=EPOCHS,\n",
        "        MEAN_MODE=MEAN_MODE, DELTA_L2=DELTA_L2,\n",
        "        SIGMA_MIN=SIGMA_MIN, SIGMA_MAX=SIGMA_MAX,\n",
        "        TOPK=TOPK, SMOOTH_EPS=SMOOTH_EPS\n",
        "    )\n",
        "    print(f\"[{itr:02d}/{len(splits)}] GBDT best_iter={m_it}  MoE best_ep={moe_ep}  \"\n",
        "          f\"TestRMSE(orig)={rmse_val:.4f}  TestNLL(z)={nll:.4f}\")\n",
        "    rmses.append(rmse_val); nlls.append(nll)\n",
        "    gc.collect()\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.empty_cache()\n",
        "\n",
        "def se(a):\n",
        "    \"\"\"Standard error of the mean across splits (sample std, ddof=1).\"\"\"\n",
        "    return a.std(ddof=1) / np.sqrt(len(a))\n",
        "\n",
        "rmses, nlls = np.array(rmses, np.float64), np.array(nlls, np.float64)\n",
        "# Label is derived from DATASET so it cannot drift from the data actually loaded.\n",
        "print(f\"\\n== Anchor-MoE (coupled anchor feature) on {DATASET.capitalize()} ==\")\n",
        "print(f\"RMSE (orig) = {rmses.mean():.4f} ± {se(rmses):.4f}\")\n",
        "print(f\"NLL  (z)    = {nlls.mean():.4f} ± {se(nlls):.4f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "v4RLpyjYGfBw",
      "metadata": {
        "id": "v4RLpyjYGfBw"
      },
      "source": [
        "# **LVMI**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ncVJ2l9BGkyx",
      "metadata": {
        "id": "ncVJ2l9BGkyx"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "gpuType": "V6E1",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
