{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import shap\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
    "from sklearn.tree import plot_tree,DecisionTreeRegressor, plot_tree\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.model_selection import train_test_split\n",
    "from scipy.stats import spearmanr\n",
    "import matplotlib\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import os\n",
    "import math, itertools, collections, warnings, random, os, json, time\n",
    "from typing import Dict, List, Tuple\n",
    "from sklearn.metrics import adjusted_rand_score as ARI, normalized_mutual_info_score as NMI\n",
    "\n",
    "from sampling_shap_gpu import sampling_shap\n",
    "from path_margin_soft import FrozenDepthTree, extract_paths, path_margin_softlabel\n",
    "# Set a specific seed value\n",
    "your_seed = 42\n",
    "\n",
    "\n",
    "def set_global_seed(seed: int = 42) -> None:\n",
    "    # ---------- Python & NumPy ----------\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "\n",
    "    # ---------- PyTorch ----------\n",
    "    torch.manual_seed(seed)           # CPU\n",
    "    torch.cuda.manual_seed_all(seed) \n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "set_global_seed(your_seed)\n",
    "matplotlib.rcParams['font.family'] = 'DejaVu Serif'\n",
    "device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ylT7L9h2eFxd",
    "outputId": "fdd80e10-f7a2-4ba9-a52d-2dbada39bec2"
   },
   "outputs": [],
   "source": [
    "\n",
    "class TreeDropout(nn.Module):\n",
    "    def __init__(self, p=0.5, keep_root=True):\n",
    "        super().__init__()\n",
    "        self.p, self.keep_root = p, keep_root\n",
    "\n",
    "    def forward(self, x):\n",
    "        if not self.training or self.p == 0.0:\n",
    "            return x\n",
    "        mask = torch.bernoulli(torch.full_like(x, 1 - self.p)) / (1 - self.p)\n",
    "        if self.keep_root:\n",
    "            mask[:, 0] = 1.0\n",
    "        return x * mask\n",
    "\n",
    "class VAE(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, latent_dim, num_clusters):\n",
    "        super(VAE, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim)\n",
    "        self.fc21 = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.fc22 = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.fc3 = nn.Linear(latent_dim, hidden_dim)\n",
    "        self.fc4 = nn.Linear(hidden_dim, input_dim)\n",
    "        self.main = nn.Sequential(\n",
    "            nn.Linear(latent_dim, latent_dim // 2),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(p=0.2),\n",
    "            nn.Linear(latent_dim // 2, 2)\n",
    "        )\n",
    "        self.cluster_layer = nn.Linear(latent_dim, num_clusters)\n",
    "        self.tree_dp = nn.Dropout(p=0.10)\n",
    "\n",
    "        # Initialize weights with Xavier initialization for stability\n",
    "        self._initialize_weights()\n",
    "\n",
    "    def _initialize_weights(self):\n",
    "        def init(m):\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.xavier_uniform_(m.weight)\n",
    "                nn.init.zeros_(m.bias)\n",
    "        self.apply(init)\n",
    "\n",
    "    def encode(self, x):\n",
    "        h1 = F.relu(self.fc1(x))\n",
    "        return self.fc21(h1), self.fc22(h1)\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        std = torch.exp(0.5 * logvar)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + eps * std\n",
    "\n",
    "    def decode(self, z):\n",
    "        h3 = F.relu(self.fc3(z))\n",
    "        return self.fc4(h3)\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu, logvar = self.encode(x.reshape(-1, input_dim))\n",
    "\n",
    "        # Clamp mu and logvar to prevent extreme values\n",
    "        mu = torch.clamp(mu, -10.0, 10.0)\n",
    "        logvar = torch.clamp(logvar, -10.0, 10.0)\n",
    "\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "\n",
    "        leaf_logits = self.cluster_layer(z)\n",
    "        \n",
    "        if (self.training                      # only when model.train()\n",
    "            and TREE_DROPOUT_P > 0\n",
    "            and epoch >= TREE_DP_WARM_EPS):\n",
    "            leaf_logits = self.tree_dp(leaf_logits)\n",
    "            \n",
    "        return self.decode(z), mu, logvar, z, self.main(z), leaf_logits\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded matrix  genes × samples  :  (60488, 424)\n",
      "Class distribution  (0 = normal, 1 = tumour/recurrent) :\n",
      "(array([0, 1]), array([ 50, 374]))\n",
      "   Final matrix shape  : (424, 60488)\n",
      "[ 44 295]\n",
      "Training data shape: (339, 60483)\n",
      "Training labels shape: (339,)\n",
      "Test data shape: (85, 60483)\n",
      "Test labels shape: (85,)\n"
     ]
    }
   ],
   "source": [
    "RAW_FILE = './content/TCGA-LIHC.htseq_counts.tsv' \n",
    "\n",
    "df = pd.read_csv(RAW_FILE, sep='\\t', index_col=0)\n",
    "print(f'Loaded matrix  genes × samples  :  {df.shape}')\n",
    "\n",
    "sample_ids   = np.array(df.columns)\n",
    "sample_codes = np.array([sid[13:15] for sid in sample_ids])\n",
    "\n",
    "keep_mask    = np.isin(sample_codes, ['01', '02', '11'])\n",
    "df           = df.loc[:, keep_mask]\n",
    "sample_ids   = sample_ids[keep_mask]\n",
    "sample_codes = sample_codes[keep_mask]\n",
    "\n",
    "y = np.where(np.isin(sample_codes, ['01', '02']), 1, 0)\n",
    "\n",
    "df.index = df.index.str.split('.').str[0]\n",
    "df_log = np.log2(df + 1)\n",
    "\n",
    "X_df = df_log.T                               \n",
    "X    = X_df.values                              \n",
    "\n",
    "X_df['label'] = y\n",
    "\n",
    "features = X_df.iloc[:, :-6]\n",
    "labels   = X_df.iloc[:, -1].replace(-1, 0)\n",
    "\n",
    "train_data, test_data, train_labels, test_labels = train_test_split(\n",
    "    features, labels, test_size=0.2, random_state=42)\n",
    "\n",
    "scaler = StandardScaler().fit(train_data)\n",
    "train_data = scaler.transform(train_data)\n",
    "test_data = scaler.transform(test_data)\n",
    "\n",
    "from torch.utils.data import WeightedRandomSampler\n",
    "class_sample_counts = np.bincount(train_labels) \n",
    "\n",
    "weights = (1. / class_sample_counts) * 1000\n",
    "sample_weights = weights[train_labels] \n",
    "sampler = WeightedRandomSampler(\n",
    "    weights=sample_weights,\n",
    "    num_samples=int(len(train_labels)),\n",
    "    replacement=True\n",
    ")\n",
    "\n",
    "print(\"Training data shape:\", train_data.shape)\n",
    "print(\"Training labels shape:\", train_labels.shape)\n",
    "print(\"Test data shape:\", test_data.shape)\n",
    "print(\"Test labels shape:\", test_labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataset(Dataset):\n",
    "    def __init__(self, data, labels):\n",
    "        self.data = data\n",
    "        self.labels = labels\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.data[idx], self.labels[idx]\n",
    "\n",
    "        \n",
    "train_data_tensor = torch.tensor(train_data, dtype=torch.float32).to('cuda')\n",
    "test_data_tensor = torch.tensor(test_data, dtype=torch.float32).to('cuda')\n",
    "train_labels_tensor = torch.tensor(train_labels.values, dtype=torch.long).to('cuda')\n",
    "test_labels_tensor = torch.tensor(test_labels.values, dtype=torch.long).to('cuda')\n",
    "\n",
    "# Dataloader\n",
    "train_loader = DataLoader(CustomDataset(train_data_tensor, train_labels_tensor), batch_size=12, sampler=sampler)\n",
    "test_loader = DataLoader(CustomDataset(test_data_tensor, test_labels_tensor), batch_size=12, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "batch_size = 12\n",
    "input_dim = train_data.shape[1]  # 260\n",
    "hidden_dim = 64\n",
    "latent_dim = 64\n",
    "# Adjust batch size for training to match the number of training samples\n",
    "num_clusters = 8  # Define the number of clusters\n",
    "\n",
    "\n",
    "model = VAE(input_dim, hidden_dim, latent_dim, num_clusters).to('cuda')\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
    "cri = torch.nn.CrossEntropyLoss(label_smoothing=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "baseline = np.zeros(input_dim, dtype=np.float32) \n",
    "\n",
    "num_explain = 32\n",
    "x_explain = train_data_tensor[:num_explain].cpu().numpy()\n",
    "\n",
    "subset_bank = None                             \n",
    "\n",
    "\n",
    "shap_log_raw, shap_log_ema = [], []\n",
    "ema = None\n",
    "def model_np(x_np_batch: torch.Tensor) -> torch.Tensor:\n",
    "    with torch.no_grad():\n",
    "        if x_np_batch.dim() == 1:\n",
    "            x_np_batch = x_np_batch.unsqueeze(0)              \n",
    "        logits = model(x_np_batch)[4]               \n",
    "\n",
    "        prob_pos = torch.softmax(logits, dim=1)[:, 1]\n",
    "        return prob_pos.squeeze(-1) \n",
    "\n",
    "\n",
    "def get_weights(ep, T=50):\n",
    "    t = ep / T\n",
    "    beta_kl   = 0.2 * 0.5 * (1 - math.cos(math.pi * min(t, .3)))\n",
    "    lam_c     = 8.0 * 0.5 * (1 - math.cos(math.pi * t))\n",
    "    lam_ce    = 0.2 + 0.8 * t\n",
    "    return beta_kl, lam_c, lam_ce\n",
    "    \n",
    "POOL_IDX = np.random.choice(len(train_data_tensor), \n",
    "                             int(0.05 * len(train_data_tensor)),\n",
    "                             replace=False)\n",
    "def shap_batch(B=16, seed=42):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    return rng.choice(POOL_IDX, B, replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np, torch, torch.nn.functional as F\n",
    "from collections import deque\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from typing import List, Tuple\n",
    "# ------------------------------\n",
    "# 可调超参\n",
    "# ------------------------------\n",
    "epochs            = 200\n",
    "ACC_GATE          = 0.58    \n",
    "CLF_PATIENCE      = 5        \n",
    "WARM_SHAP_WINDOW  = max(CLF_PATIENCE, 5)  \n",
    "SHAP_B            = 32       \n",
    "\n",
    "k_feat            = int(input_dim * 0.25)        \n",
    "TREE_EVERY        = 4    \n",
    "UPDATE_EVERY      = 2         \n",
    "STAB_THRES        = 0.92     \n",
    "TREE_PATIENCE     = TREE_EVERY + 1  \n",
    "\n",
    "MAX_TREE_DEPTH    = 10        \n",
    "MIN_LEAF_FRAC     = 0.005    \n",
    "LAM_C_BOOST_EPOCH = 30       \n",
    "\n",
    "BETA_FAST = 0.85\n",
    "BETA_SLOW = 0.6\n",
    "\n",
    "# Tree-regularization surrogate params\n",
    "LAM_TREE      = 4            \n",
    "SURR_REFRESH_EVERY = 50      \n",
    "INIT_TRAIN_SIZE = 25  \n",
    "INCREMENT_K = 25         \n",
    "MAX_SNAP_DIM  = 50000        \n",
    "SURR_L2 = 1e-4\n",
    "HIST_LEN       = 100         \n",
    "\n",
    "# Tree dropout\n",
    "TREE_DROPOUT_P   = 0.10      \n",
    "TREE_DP_WARM_EPS = 5        \n",
    "\n",
    "stage         = 'clf'        \n",
    "frozen        = False         \n",
    "lam_c_on      = False        \n",
    "\n",
    "ema_fast      = np.zeros(input_dim, dtype=np.float32)\n",
    "ema_slow      = np.zeros_like(ema_fast)\n",
    "\n",
    "\n",
    "shap_pool     = deque(maxlen=WARM_SHAP_WINDOW) \n",
    "U_star        = None         \n",
    "T_star        = None\n",
    "tree_buffer   = deque(maxlen=3)\n",
    "dt_model      = None        \n",
    "last_tree_ep  = -999        \n",
    "last_ari      = 0.0       \n",
    "tree_no_upd   = 0            \n",
    "n_sampling = 2\n",
    "\n",
    "no_accept_streak = 0\n",
    "snapshots = deque(maxlen=HIST_LEN)\n",
    "surrogate_bundle = {'model': None, 'opt': None}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def average_path_length(tree, X: np.ndarray) -> float:\n",
    "    \"\"\"Train a CART on (X,y_hat) and return mean path length.\"\"\"\n",
    "    path = tree.decision_path(X)\n",
    "    n_nodes = np.diff(path.indptr)    # samples → node counts\n",
    "    return float((n_nodes - 1).mean())\n",
    "\n",
    "def live_flatten_params(model):\n",
    "    v = torch.cat([p.view(-1) for p in model.parameters()])\n",
    "    return v\n",
    "def snapshot_params(model: nn.Module) -> torch.Tensor:\n",
    "    with torch.no_grad():\n",
    "        v = torch.cat([p.detach().flatten().cpu() for p in model.parameters()])\n",
    "    return v\n",
    "    \n",
    "class APLSurrogate(nn.Module):\n",
    "    def __init__(self, d_in):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(d_in, 512), nn.ReLU(),\n",
    "            nn.Linear(512, 256), nn.ReLU(),\n",
    "            nn.Linear(256, 1),\n",
    "            nn.Softplus()\n",
    "        )\n",
    "    def forward(self, w):\n",
    "        return self.net(w).squeeze(-1)\n",
    "        \n",
    "def train_surrogate(surrogate: APLSurrogate,\n",
    "                    opt_surr: torch.optim.Optimizer,\n",
    "                    snapshots: List[Tuple[torch.Tensor, float]],\n",
    "                    device: torch.device,\n",
    "                    epochs: int = 20,\n",
    "                    batch_size: int = 32) -> None:\n",
    "    \"\"\"\n",
    "    snapshots: [(w_vec_cpu, apl_float), ...] \n",
    "    \"\"\"\n",
    "    global since_last_update\n",
    "    W = torch.stack([w for (w, _) in snapshots])          # (N,D) CPU\n",
    "    A = torch.tensor([a for (_, a) in snapshots], dtype=torch.float32)\n",
    "    dataset = torch.utils.data.TensorDataset(W, A)\n",
    "    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    surrogate.train()\n",
    "    for _ in range(epochs):\n",
    "        for wb, ab in loader:\n",
    "            wb, ab = wb.to(device), ab.to(device)\n",
    "            pred = surrogate(wb)\n",
    "            loss_s = F.mse_loss(pred, ab)\n",
    "            opt_surr.zero_grad()\n",
    "            loss_s.backward()\n",
    "            opt_surr.step()\n",
    "    since_last_update = 0\n",
    "            \n",
    "def update_surrogate(surrogate_bundle, snapshots, device):\n",
    "    if len(snapshots) < INIT_TRAIN_SIZE:\n",
    "        return\n",
    "    if surrogate_bundle['model'] is None:\n",
    "        d_in = snapshots[0][0].numel()\n",
    "        surrogate_bundle['model'] = APLSurrogate(d_in).to(device)\n",
    "        surrogate_bundle['opt'] = torch.optim.Adam(\n",
    "            surrogate_bundle['model'].parameters(), 1e-4, weight_decay=SURR_L2)\n",
    "\n",
    "    W = torch.stack([w for (w, _) in snapshots])\n",
    "    A = torch.tensor([a for (_, a) in snapshots], dtype=torch.float32)\n",
    "    train_surrogate(surrogate_bundle['model'],\n",
    "                    surrogate_bundle['opt'],\n",
    "                    list(snapshots), device,\n",
    "                    epochs=100, batch_size=16)\n",
    "    \n",
    "def refresh_tree(model, X_sub, y_target, snapshots, device, seed):\n",
    "    min_leaf = max(1, int(MIN_LEAF_FRAC * len(X_U)) + 10)\n",
    "    dt, apl_true, val_mse = fit_tree(\n",
    "        X_sub, y_target,\n",
    "        min_samples_leaf=min_leaf,\n",
    "        seed=seed,\n",
    "        lambda_apl=1e-4\n",
    "    )             \n",
    "    # apl_true = average_path_length(dt, X_sub)\n",
    "    w_snapshot = snapshot_params(model)\n",
    "    return dt, apl_true,w_snapshot\n",
    "\n",
    "def tree_regularization_loss(model: nn.Module,\n",
    "                             surrogate_bundle: dict,\n",
    "                             device: torch.device) -> torch.Tensor:\n",
    "    surrogate = surrogate_bundle.get('model', None)\n",
    "    if surrogate is None:\n",
    "        return torch.zeros((), device=device)\n",
    "    surrogate.eval()\n",
    "    w_live = live_flatten_params(model).unsqueeze(0).to(device)\n",
    "    with torch.set_grad_enabled(True):\n",
    "        omega_hat = surrogate(w_live).squeeze(0)  # 标量（可反传）\n",
    "    return omega_hat\n",
    "\n",
    "import time\n",
    "def collect_shap_rank(epoch, B=SHAP_B):\n",
    "    idxs = np.random.choice(len(train_data_tensor), B, replace=False)\n",
    "\n",
    "    phi_acc = torch.zeros(input_dim, device=device)\n",
    "    for i in idxs:\n",
    "        x_i = train_data_tensor[i].to(device).float()             # (D,)\n",
    "        baseline = torch.zeros_like(x_i, device=device)           # (D,)\n",
    "        phi_i, _ = sampling_shap(\n",
    "            model_np,\n",
    "            x_i,\n",
    "            baseline=baseline,\n",
    "            k=10,\n",
    "            seed= 42,\n",
    "            reg_alpha=5e-3\n",
    "        )  \n",
    "        phi_acc += phi_i.abs()\n",
    "    phi_mean_t = phi_acc / float(B)          \n",
    "    phi_mean = phi_mean_t.detach().cpu().numpy()\n",
    "    return np.argsort(-phi_mean), phi_mean\n",
    "\n",
    "\n",
    "def fit_tree(\n",
    "    X, y,\n",
    "    min_samples_leaf=25,\n",
    "    seed=42,\n",
    "    val_ratio=0.2,\n",
    "    lambda_apl=1e-4,     \n",
    "    max_alphas=30         \n",
    "):\n",
    "    X_tr, X_val, y_tr, y_val = train_test_split(\n",
    "        X, y, test_size=val_ratio, random_state=seed\n",
    "    )\n",
    "\n",
    "    base_tree = DecisionTreeRegressor(\n",
    "        min_samples_leaf=min_samples_leaf,\n",
    "        random_state=seed\n",
    "    ).fit(X_tr, y_tr)\n",
    "\n",
    "    path = base_tree.cost_complexity_pruning_path(X_tr, y_tr)\n",
    "    ccp_alphas = path.ccp_alphas[:-1]\n",
    "    if len(ccp_alphas) == 0:\n",
    "        apl = average_path_length(base_tree, X)\n",
    "        pred_val = base_tree.predict(X_val)\n",
    "        mse = float(((pred_val - y_val)**2).mean())\n",
    "        return base_tree, float(apl), float(mse)\n",
    "\n",
    "    if len(ccp_alphas) > max_alphas:\n",
    "        idx = np.linspace(0, len(ccp_alphas)-1, max_alphas).astype(int)\n",
    "        ccp_alphas = ccp_alphas[idx]\n",
    "\n",
    "    best_tree = base_tree\n",
    "    best_score = np.inf\n",
    "    best_apl, best_mse = None, None\n",
    "    \n",
    "    pred_val = best_tree.predict(X_val)\n",
    "    best_mse = np.mean((pred_val - y_val) ** 2)    \n",
    "    best_apl = average_path_length(best_tree, X)\n",
    "    # for alpha in ccp_alphas:\n",
    "    #     t = DecisionTreeRegressor(\n",
    "    #         min_samples_leaf=min_samples_leaf,\n",
    "    #         random_state=seed,\n",
    "    #         ccp_alpha=alpha\n",
    "    #     ).fit(X_tr, y_tr)\n",
    "\n",
    "    #     pred_val = t.predict(X_val)\n",
    "    #     mse = np.mean((pred_val - y_val) ** 2)\n",
    "\n",
    "    #     apl = average_path_length(t, X)\n",
    "    #     if apl <= 2: continue\n",
    "    #     score = mse + lambda_apl * apl\n",
    "    #     if score < best_score:\n",
    "    #         best_score = score\n",
    "    #         best_tree  = t\n",
    "    #         best_apl   = apl\n",
    "    #         best_mse   = mse\n",
    "\n",
    "    return best_tree,float(best_apl), float(best_mse)\n",
    "\n",
    "def equal_token(tok_a, tok_b, eps=1e-3):\n",
    "    fi1, th1, dir1 = tok_a\n",
    "    fi2, th2, dir2 = tok_b\n",
    "    return (fi1 == fi2) and (dir1 == dir2) and (abs(th1 - th2) <= eps)\n",
    "    \n",
    "def path_distance_tol(p1, p2, eps=1e-3):\n",
    "    k = 0\n",
    "    for t1, t2 in zip(p1, p2):\n",
    "        if equal_token(t1, t2, eps=eps):\n",
    "            k += 1\n",
    "        else:\n",
    "            break\n",
    "    return (len(p1) - k) + (len(p2) - k)\n",
    "\n",
    "def path_similarity(p1, p2, beta=1.0, eps=1e-3):\n",
    "    dist = path_distance_tol(p1, p2, eps=eps)\n",
    "    return 1.0 / (1.0 + beta * dist)\n",
    "\n",
    "def _rank_corr(a, b, eps=1e-12):\n",
    "    ar = a.argsort().argsort().astype(np.float32)\n",
    "    br = b.argsort().argsort().astype(np.float32)\n",
    "    ar = (ar - ar.mean()) / (ar.std() + eps)\n",
    "    br = (br - br.mean()) / (br.std() + eps)\n",
    "    return float((ar * br).mean())\n",
    "\n",
    "\n",
    "def score_tree_value(new_tree, old_tree, \n",
    "                     X_all, y_target,             \n",
    "                     U_star, ema_slow,          \n",
    "                     epoch,\n",
    "                     seed,\n",
    "                     lambda_apl=7.5e-4,          \n",
    "                     w_stab=0.20,                 \n",
    "                     w_align=0.1,                 \n",
    "                     w_jump=0.05,                 \n",
    "                     accept_margin=1e-4):         \n",
    "    import os\n",
    "    import numpy as np\n",
    "    from sklearn.model_selection import train_test_split\n",
    "    from sklearn.metrics import normalized_mutual_info_score as NMI\n",
    "\n",
    "    total_steps = 300\n",
    "    warmup_frac   = 0.25\n",
    "    anneal_frac   = 0.50\n",
    "    nmi_start     = 0.30\n",
    "    nmi_end       = 0.70\n",
    "    align_min     = 0.35\n",
    "    big_gain_rel  = 0.02\n",
    "    big_gain_nmi_floor = 0.3\n",
    "\n",
    "    warmup_steps = int(total_steps * warmup_frac)\n",
    "    anneal_steps = int(total_steps * anneal_frac)\n",
    "    if epoch < warmup_steps:\n",
    "        nmi_min_now = 0.0\n",
    "    else:\n",
    "        if anneal_steps <= 0:\n",
    "            nmi_min_now = nmi_end\n",
    "        else:\n",
    "            t = min(max(epoch - warmup_steps, 0), anneal_steps) / float(anneal_steps)\n",
    "            nmi_min_now = nmi_start + (nmi_end - nmi_start) * t\n",
    "\n",
    "    X_tr, X_val, y_tr, y_val = train_test_split(X_all, y_target, test_size=0.2, random_state=seed)\n",
    "\n",
    "\n",
    "    def _mse_apl(t):\n",
    "        pred = t.predict(X_val)\n",
    "        mse  = float(((pred - y_val)**2).mean())\n",
    "        apl  = average_path_length(t, X_all)  \n",
    "        return mse, apl\n",
    "\n",
    "    mse_new, apl_new = _mse_apl(new_tree)\n",
    "    mse_old, apl_old = _mse_apl(old_tree)\n",
    "\n",
    "    delta_mse = mse_old - mse_new\n",
    "    rel_gain = delta_mse / (mse_old + 1e-12)\n",
    "\n",
    "    leaf_new = new_tree.apply(X_all)\n",
    "    leaf_old = old_tree.apply(X_all)\n",
    "    nmi = NMI(leaf_old, leaf_new)\n",
    "\n",
    "    imp = getattr(new_tree, 'feature_importances_', None)\n",
    "    if imp is None or len(imp) != len(U_star):\n",
    "        align = 0.0\n",
    "    else:\n",
    "        align = _rank_corr(np.asarray(imp), np.asarray(ema_slow[U_star]))\n",
    "\n",
    "    n_new = len(np.unique(leaf_new)); n_old = len(np.unique(leaf_old))\n",
    "    jump_ratio = max(0, abs(n_new - n_old) / (n_old + 1e-8))\n",
    "\n",
    "    J_new = mse_new + lambda_apl * apl_new\n",
    "    J_old = mse_old + lambda_apl * apl_old\n",
    "    base = 5e-3 if epoch < 100 else 2e-3 if epoch < 200 else 1e-3\n",
    "    accept_margin = max(1e-4, base * J_old)\n",
    "\n",
    "    epsilon = accept_margin\n",
    "    improved = (J_new + epsilon) < J_old\n",
    "\n",
    "    align_ok = (align >= align_min)\n",
    "    nmi_ok   = (nmi >= nmi_min_now)\n",
    "\n",
    "    if rel_gain >= big_gain_rel:\n",
    "        nmi_ok = (nmi >= big_gain_nmi_floor)\n",
    "\n",
    "    accept = bool(improved and align_ok and nmi_ok)\n",
    "\n",
    "    L_new = J_new\n",
    "    L_old = J_old\n",
    "    log = (f\"[gate] new_vs_old | \"\n",
    "           f\"MSE {mse_new:.4f}/{mse_old:.4f}  \"\n",
    "           f\"APL {apl_new:.2f}/{apl_old:.2f}  \"\n",
    "           f\"NMI {nmi:.3f} (τ={nmi_min_now:.2f})  \"\n",
    "           f\"align {align:.3f} (τ={align_min:.2f})  \"\n",
    "           f\"leaves {n_old}->{n_new}  \"\n",
    "           f\"J {J_new:.5f}/{J_old:.5f}  \"\n",
    "           f\"Δrel {rel_gain:.3f}  \"\n",
    "           f\"ε {epsilon:.2e}  \"\n",
    "           f\"accept={accept}\")\n",
    "\n",
    "    return accept, dict(\n",
    "        mse_new=mse_new, mse_old=mse_old,\n",
    "        apl_new=apl_new, apl_old=apl_old,\n",
    "        nmi=nmi, align=align,\n",
    "        n_old=n_old, n_new=n_new,\n",
    "        L_new=L_new, L_old=L_old, \n",
    "        log=log\n",
    "    )\n",
    "\n",
    "def pick_topk_by_coverage(score_vec, cover=0.90, kmin=16, kmax=128):\n",
    "    s = np.asarray(score_vec, dtype=np.float32)\n",
    "    s = np.abs(s)                    \n",
    "    k = min(kmax, int(0.25 * input_dim))\n",
    "    rank = np.argsort(-s)        \n",
    "    U_star = rank[:k]\n",
    "    return U_star, k\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "for epoch in range(epochs):\n",
    "    model.train()\n",
    "    tot_loss = BCE = KLD = CL = CE = TREE = 0\n",
    "    beta_kl, lam_c, lam_ce = get_weights(epoch)\n",
    "    if epoch == 50:        \n",
    "        LAM_TREE *= 2\n",
    "    for xb, yb in train_loader:\n",
    "        xb, yb = xb.cuda(), yb.cuda()\n",
    "        recon, mu, logv, z, logits, cl_out = model(xb)   \n",
    "        # ---- VAE loss ----\n",
    "        bce = F.mse_loss(recon, xb, reduction='mean')\n",
    "        kld = -0.5*torch.sum(1 + logv - mu.pow(2) - logv.exp()) / xb.size(0)\n",
    "        if (U_star is None) or (not lam_c_on):                             # warm-up 前5轮 λ_c=0\n",
    "            closs = torch.tensor(0., device=xb.device)\n",
    "        else:\n",
    "            soft_np = path_margin_softlabel(dt_model, xb[:, U_star].cpu().numpy(), tau = 5 * 0.5 * (1 + np.cos(np.pi * epoch / epochs)))\n",
    "            closs = F.kl_div(F.log_softmax(cl_out,1), torch.tensor(soft_np, device=xb.device), reduction='batchmean')\n",
    "            \n",
    "        celoss   = cri(logits, yb)\n",
    "\n",
    "\n",
    "        # ---- 树正则 (APL surrogate) ----   # <<< NEW\n",
    "        if surrogate_bundle['model'] is not None:\n",
    "            tree_reg = tree_regularization_loss(model.main, surrogate_bundle,\n",
    "                                            device)\n",
    "        else:\n",
    "            tree_reg    = torch.zeros((), device=xb.device)\n",
    "            \n",
    "        vae_loss = bce + beta_kl*kld + lam_c*closs\n",
    "        loss     = vae_loss + lam_ce*celoss + 0 * tree_reg\n",
    "\n",
    "\n",
    "        optimizer.zero_grad(); loss.backward(); optimizer.step()\n",
    "\n",
    "    if stage == 'clf':\n",
    "        shap_rank, _ = collect_shap_rank(epoch, B=SHAP_B)\n",
    "        shap_pool.append(shap_rank)\n",
    "\n",
    "        if len(shap_pool) >=  WARM_SHAP_WINDOW:\n",
    "            vote = np.zeros(input_dim, dtype=float)\n",
    "            weights = np.linspace(1.0, 0.5, num=len(shap_pool))\n",
    "            for r, w in zip(shap_pool, weights):\n",
    "                vote[r[:k_feat]] += w\n",
    "            U_star, k_feat = pick_topk_by_coverage(vote, cover=0.90, kmin=16, kmax=128)\n",
    "            X_U = train_data_tensor[:, U_star].cpu().numpy()\n",
    "            y_target = model_np(train_data_tensor.to(device)).detach().cpu().numpy()  # full → prob\n",
    "\n",
    "            dt_model, apl_true, w_snap = refresh_tree(\n",
    "                model.main,\n",
    "                X_U,\n",
    "                y_target,\n",
    "                snapshots,\n",
    "                device=next(model.parameters()).device,\n",
    "                seed=42,\n",
    "            )\n",
    "            snapshots.append((w_snap, apl_true))\n",
    "            tree_buffer.append(dt_model)\n",
    "            update_surrogate(surrogate_bundle, snapshots, device)\n",
    "            n_leaf = len(np.unique(dt_model.apply(X_U)))\n",
    "            model.cluster_layer = torch.nn.Linear(latent_dim, n_leaf).cuda()\n",
    "\n",
    "            freq = np.zeros(input_dim, dtype=np.float32)\n",
    "            for r in shap_pool:\n",
    "                freq[r[:k_feat]] += 1.0\n",
    "            if freq.max() > 0:\n",
    "                phi_init = freq / freq.max()\n",
    "            else:\n",
    "                phi_init = freq\n",
    "            ema_slow[:] = phi_init\n",
    "            \n",
    "            \n",
    "            lam_c_on = True         \n",
    "            stage    = 'tree'\n",
    "            T_star = U_star\n",
    "        continue \n",
    "\n",
    "    shap_rank, phi_mean = collect_shap_rank(epoch, B=SHAP_B)\n",
    "    ema_slow[:] = BETA_SLOW * ema_slow + (1 - BETA_SLOW) * phi_mean\n",
    "\n",
    "    cand, k_feat = pick_topk_by_coverage(ema_slow, cover=0.90, kmin=16, kmax=128)\n",
    "    \n",
    "    keep = list(set(U_star) & set(cand))\n",
    "    need = max(0, k_feat - len(keep))\n",
    "    add  = [i for i in cand if i not in keep][:need]\n",
    "    U_star = np.array(keep + add, dtype=int)\n",
    "\n",
    "    X_U = train_data_tensor[:, U_star].cpu().numpy()\n",
    "\n",
    "    y_target = model_np(train_data_tensor.to(device)).detach().cpu().numpy()\n",
    "\n",
    "    min_leaf = int(MIN_LEAF_FRAC * len(X_U)) + 10\n",
    "    new_tree, apl_true, w_snap = refresh_tree(\n",
    "            model.main,\n",
    "            X_U,\n",
    "            y_target,\n",
    "            snapshots,\n",
    "            device=next(model.parameters()).device,\n",
    "            seed=42,\n",
    "        )\n",
    "    accept_tree, metrics = score_tree_value(new_tree, dt_model,\n",
    "                                            X_U, y_target,\n",
    "                                            U_star, ema_slow,\n",
    "                                            epoch,\n",
    "                                            seed=42)\n",
    "    snapshots.append((w_snap, apl_true))\n",
    "    if accept_tree or epoch < 50:\n",
    "        T_star = U_star\n",
    "        tree_buffer.append(new_tree)\n",
    "        dt_prev = dt_model\n",
    "        dt_model = new_tree\n",
    "\n",
    "        leaf_prev = dt_prev.apply(X_U)\n",
    "        leaf_curr = dt_model.apply(X_U)\n",
    "        \n",
    "        nmi  = NMI(leaf_prev, leaf_curr)   \n",
    "    \n",
    "        def leaf_hash(tree):\n",
    "            if isinstance(tree, FrozenDepthTree):\n",
    "                paths = tree.get_paths()    \n",
    "            else:\n",
    "                paths = extract_paths(tree)      \n",
    "            out = {}\n",
    "            for leaf_id in sorted(paths.keys()):\n",
    "                out[tuple(paths[leaf_id])] = len(out)\n",
    "            return out         \n",
    "    \n",
    "                    \n",
    "        if len(tree_buffer) >= 2:\n",
    "            prev_hash = leaf_hash(tree_buffer[-2])\n",
    "            curr_hash = leaf_hash(tree_buffer[-1])\n",
    "            id_map = {curr: prev_hash[path] for path, curr in curr_hash.items() if path in prev_hash}\n",
    "        else:\n",
    "            prev_hash, curr_hash, id_map = {}, {}, {}\n",
    "        \n",
    "        new_n = len(curr_hash)\n",
    "        old_layer = model.cluster_layer\n",
    "        new_layer = nn.Linear(latent_dim, new_n).cuda()\n",
    "   \n",
    "        old_W, old_B = old_layer.weight, old_layer.bias           \n",
    "        new_W, new_B = new_layer.weight, new_layer.bias         \n",
    "        \n",
    "        with torch.no_grad():\n",
    "            IOU_TH = 0.25            \n",
    "            for new_path, new_id in curr_hash.items():\n",
    "                best_sim, best_old_id = 0., None\n",
    "                for old_path, old_id in prev_hash.items():\n",
    "                    sim = path_similarity(new_path, old_path, beta=1.0, eps=0.01)\n",
    "                    if sim > best_sim:\n",
    "                        best_sim, best_old_id = sim, old_id\n",
    "    \n",
    "                if best_old_id is not None:\n",
    "                    A = (leaf_prev == best_old_id)\n",
    "                    B = (leaf_curr == new_id)\n",
    "                    denom = (A | B).sum()\n",
    "                    iou = (A & B).sum() / (denom + 1e-8)\n",
    "                    if denom == 0: continue\n",
    "                    if iou >= IOU_TH:       \n",
    "                        wgt = best_sim * iou                \n",
    "                        new_W[new_id].mul_(1-wgt).add_(wgt * old_W[best_old_id])\n",
    "                        new_B[new_id].mul_(1-wgt).add_(wgt * old_B[best_old_id])\n",
    "\n",
    "        model.cluster_layer = new_layer\n",
    "\n",
    "    if len(snapshots)>INIT_TRAIN_SIZE:\n",
    "        update_surrogate(surrogate_bundle, snapshots,\n",
    "                                  device=next(model.parameters()).device)\n",
    "        if surrogate_bundle['model'] is not None:\n",
    "            w_live = snapshot_params(model.main).unsqueeze(0).to(device)\n",
    "            with torch.no_grad():\n",
    "                omega_hat = surrogate_bundle['model'](w_live).item()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python (triVae01)",
   "language": "python",
   "name": "trivae01"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
