{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "39f55327",
   "metadata": {},
   "source": [
    "## Section 1: Dependency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6f40f833",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n",
      "GPU: Tesla V100-SXM2-32GB\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/users/yhung7/.local/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau\n",
    "from torch.cuda.amp import GradScaler, autocast\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import SplineTransformer, StandardScaler\n",
    "from sklearn.model_selection import train_test_split, KFold\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "from itertools import combinations\n",
    "from tqdm.auto import tqdm\n",
    "from pathlib import Path\n",
    "import time\n",
    "import copy\n",
    "from dataclasses import dataclass\n",
    "from typing import List, Optional, Dict, Tuple\n",
    "import warnings\n",
    "import pandas as pd\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Device configuration\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea8940f2",
   "metadata": {},
   "source": [
    "## Section 2: Splines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "97041836",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SplineProcessor:\n",
    "    \n",
    "    def __init__(self, n_knots=10, boundary_knots=3, degree=3, custom=False):\n",
    "        self.n_knots = n_knots\n",
    "        self.boundary_knots = boundary_knots\n",
    "        self.degree = degree\n",
    "        self.custom = custom\n",
    "    \n",
    "    @staticmethod\n",
    "    def customized_knots(X: np.ndarray, boundary_knots: int, n_knots: int, \n",
    "                        degree: int, unif: bool = False) -> np.ndarray:\n",
    "        lwbnd, upbnd, Xstd = X.min(), X.max(), X.std()\n",
    "        \n",
    "        if unif:\n",
    "            return np.linspace(lwbnd, upbnd, n_knots).reshape(-1, 1)\n",
    "        \n",
    "        lw_knots = np.linspace(lwbnd - 0.1*Xstd, lwbnd + 0.1*Xstd, boundary_knots)\n",
    "        mid_knots = np.linspace(lwbnd + 0.1*Xstd, upbnd - 0.1*Xstd, \n",
    "                               n_knots - boundary_knots)\n",
    "        up_knots = np.linspace(upbnd - 0.1*Xstd, upbnd + 0.1*Xstd, boundary_knots)\n",
    "        knots = np.unique(np.concatenate([lw_knots, mid_knots, up_knots]))\n",
    "        return knots.reshape(-1, 1)\n",
    "    \n",
    "    def transform(self, X: np.ndarray, extrapolation: str = 'constant') -> np.ndarray:\n",
    "        \"\"\"Œspline transformation\"\"\"\n",
    "        X = np.asarray(X).reshape(-1, 1)\n",
    "        \n",
    "        if self.custom:\n",
    "            knots = self.customized_knots(X, self.boundary_knots, \n",
    "                                         self.n_knots, self.degree)\n",
    "            spline = SplineTransformer(knots=knots, degree=self.degree, \n",
    "                                      extrapolation=extrapolation)\n",
    "        else:\n",
    "            spline = SplineTransformer(n_knots=self.n_knots, degree=self.degree,\n",
    "                                      extrapolation=extrapolation)\n",
    "        \n",
    "        return spline.fit_transform(X)\n",
    "    \n",
    "    @staticmethod\n",
    "    def diff_penalty(num_coefs: int, order: int = 2) -> np.ndarray:\n",
    "        return np.diff(np.eye(num_coefs), n=order, axis=0)\n",
    "    \n",
    "    def smoothing_matrix(self, x: np.ndarray, lam: float = 0.1) -> np.ndarray:\n",
    "        B = self.transform(x)\n",
    "        D = self.diff_penalty(B.shape[1], order=2)\n",
    "        \n",
    "        BTB = B.T @ B\n",
    "        DTD = D.T @ D\n",
    "        S = B @ np.linalg.solve(BTB + lam * DTD, B.T)\n",
    "        return S\n",
    "    \n",
    "def spline_transform(X, n_knots=None, boundary_knots=None, degree=3, \n",
    "                    custom=False, extrapolation='constant'):\n",
    "    processor = SplineProcessor(n_knots, boundary_knots, degree, custom)\n",
    "    return processor.transform(X, extrapolation)\n",
    "\n",
    "def PS_smoothing_matrix(x, lam, n_knots, degree, boundary_knots, custom):\n",
    "    processor = SplineProcessor(n_knots, boundary_knots, degree, custom)\n",
    "    return processor.smoothing_matrix(x, lam)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fb37c66",
   "metadata": {},
   "source": [
    "## SAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fe9bc2f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def SAM(X, y, lam = 0, alpha=0.25, max_iter=10, tol=1e-6, ftol = 1e-3, n_knots=10, boundary_knots = 3, degree=3, custom = False):\n",
    "    n_samples, n_features = X.size()\n",
    "    whole_feature = set(list(range(n_features)))\n",
    "    feature_space = (list(range(n_features)))\n",
    "    flag = [True] * n_features\n",
    "    f = torch.zeros((n_samples, n_features))\n",
    "    R = torch.clone(y)\n",
    "    \n",
    "    for _ in range(max_iter):\n",
    "        f_old = torch.clone(f)\n",
    "        flag = [True] * len(feature_space)\n",
    "        for j in range(len(feature_space)):\n",
    "            u_space_idx = [feature_space[j]]\n",
    "            res_space_idx = list(whole_feature-set(u_space_idx))\n",
    "            Res = R - f[:, res_space_idx].sum(axis=1)\n",
    "            Res = torch.FloatTensor(Res)\n",
    "            \n",
    "            PS_matrix = PS_smoothing_matrix(X[:, j], lam = lam, n_knots = n_knots, degree = degree, boundary_knots = boundary_knots, custom = custom)\n",
    "            PS_matrix = torch.FloatTensor(PS_matrix)\n",
    "\n",
    "            P_j = PS_matrix @ Res\n",
    "            s_j = torch.sqrt(torch.mean(P_j**2))\n",
    "            if s_j > alpha and flag[j]:\n",
    "                f[:, feature_space[j]] =  (1 - alpha / s_j) * P_j\n",
    "            else:\n",
    "                flag[j] = False\n",
    "                f[:, feature_space[j]] = 0\n",
    "            \n",
    "            del PS_matrix\n",
    "            \n",
    "        tfs = []\n",
    "        for b in range(len(flag)):\n",
    "            if flag[b]:\n",
    "                tfs.append(feature_space[b])\n",
    "        feature_space = tfs\n",
    "\n",
    "        if (torch.sum(torch.square(f - f_old)) < tol):\n",
    "            print(f\"Alpha: {alpha:.2f} | Convergence.\")\n",
    "\n",
    "            return f\n",
    "\n",
    "    print(f\"Alpha: {alpha:.2f} | Need more iteration.\")\n",
    "\n",
    "    return f\n",
    "\n",
    "def train_SAM(X, y, alpha_list, max_iter, nk, nb, custom):\n",
    "\n",
    "    Max_L1 = torch.zeros((len(alpha_list)))\n",
    "    component_list = {}\n",
    "    result = {}\n",
    "\n",
    "    for i in range(len(alpha_list)):\n",
    "        f = SAM(X, y, lam = 0.1, alpha = alpha_list[i], max_iter = max_iter, tol=1e-6, ftol = 1e-3, n_knots=nk, boundary_knots = nb, degree=3, custom = custom)\n",
    "        # Identify the non-active features among iteration\n",
    "\n",
    "        nonact_idx = torch.where(torch.sum(torch.square(f), axis = 0) == 0)[0]\n",
    "        Max_L1[i] = torch.sum(torch.norm(f, dim = 0))\n",
    "\n",
    "        component_list[i] = f\n",
    "        component_list[i][:, nonact_idx] = 0\n",
    "\n",
    "    GCV_list, loc, active_dict = estimate_gcv(alpha_list, component_list, X, y)\n",
    "    Max_L1 = torch.max(Max_L1)\n",
    "    result['component'] = component_list\n",
    "    result['GCV'] = GCV_list\n",
    "    result['opt_loc'] = loc\n",
    "    result['opt_var'] = active_dict\n",
    "    result['Max_L1'] = Max_L1\n",
    "    result['alpha'] = alpha_list\n",
    "    \n",
    "    return result\n",
    "\n",
    "## Evaluation best feature \n",
    "\n",
    "def estimate_gcv(alpha, comp, X, y):\n",
    "\n",
    "    def estimate_sigma2(y, y_pred, edf_total):\n",
    "        n = len(y)\n",
    "        rss = torch.sum(torch.square(y - y_pred))\n",
    "        if edf_total > n:\n",
    "            return rss\n",
    "        else:\n",
    "            return rss / (n - edf_total)\n",
    "\n",
    "    def compute_cp(y, y_pred, edf_total, sigma2_hat):\n",
    "        n = len(y)\n",
    "        rss = torch.sum(torch.square(y - y_pred))\n",
    "        cp = (rss / n) + 2 * (sigma2_hat / n) * edf_total\n",
    "        return cp\n",
    "\n",
    "    CP = torch.zeros_like(alpha)\n",
    "    n_samples = y.size()[0]\n",
    "    Critical_point = None\n",
    "    \n",
    "    for i in range(len(alpha)-1, -1, -1):\n",
    "        pred_y = comp[i].sum(axis = 1)\n",
    "        \n",
    "        active_set = torch.where(torch.norm(comp[i], p = 1, dim=0) != 0)[0].tolist()\n",
    "        effective_df = 0\n",
    "        for idx in range(len(active_set)):\n",
    "            PS_matrix = PS_smoothing_matrix(X[:, active_set[idx]], lam = 0.1, n_knots = 10, degree = 3, boundary_knots = 3, custom = False)    \n",
    "            #effective_df += (np.trace(PS_matrix)/n_samples)\n",
    "            effective_df += np.trace(PS_matrix)\n",
    "\n",
    "            del PS_matrix\n",
    "\n",
    "        rss = torch.sum(torch.square(y - pred_y))\n",
    "        # Condition for over-fitting\n",
    "        if effective_df > n_samples:\n",
    "            sigma2_hat = rss\n",
    "            if Critical_point == None:\n",
    "                Critical_point = i\n",
    "        else:\n",
    "            sigma2_hat = estimate_sigma2(y, pred_y, effective_df)\n",
    "\n",
    "        cp_value = compute_cp(y, pred_y, effective_df, sigma2_hat)\n",
    "        \n",
    "        CP[i] = cp_value\n",
    "\n",
    "    # Simplify the plot\n",
    "    CP[:Critical_point] = torch.max(CP[Critical_point:])   \n",
    "    minCP = torch.where(CP == torch.min(CP))[0]\n",
    "    optimalloc_ = minCP.item() if len(minCP) == 1 else minCP[0].item()\n",
    "    optimalset_ = torch.where(torch.norm(comp[optimalloc_], p = 1, dim=0) != 0)[0].tolist()\n",
    "    \n",
    "    return CP, optimalloc_, optimalset_\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59e378bb",
   "metadata": {},
   "source": [
    "## Modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c21238ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImprovedFeatureNet(nn.Module):\n",
    "    \"\"\"æ”¹é€²çš„ç‰¹å¾µå­ç¶²è·¯ - åŠ å…¥BN, Dropout, å¯é¸Residual connections\"\"\"\n",
    "    \n",
    "    def __init__(self, input_dim: int, hidden_dims: List[int], \n",
    "                 use_bn: bool = True, dropout: float = 0.1, \n",
    "                 activation: str = 'relu', use_residual: bool = False):\n",
    "        super().__init__()\n",
    "        self.use_residual = use_residual and (input_dim == hidden_dims[-1])\n",
    "        \n",
    "        layers = []\n",
    "        in_dim = input_dim\n",
    "        \n",
    "        for i, h_dim in enumerate(hidden_dims):\n",
    "            layers.append(nn.Linear(in_dim, h_dim))\n",
    "            \n",
    "            if use_bn:\n",
    "                layers.append(nn.BatchNorm1d(h_dim))\n",
    "            \n",
    "            # Activation\n",
    "            if activation == 'relu':\n",
    "                layers.append(nn.ReLU())\n",
    "            elif activation == 'leaky_relu':\n",
    "                layers.append(nn.LeakyReLU(0.2))\n",
    "            elif activation == 'gelu':\n",
    "                layers.append(nn.GELU())\n",
    "            \n",
    "            if dropout > 0 and i < len(hidden_dims) - 1:\n",
    "                layers.append(nn.Dropout(dropout))\n",
    "            \n",
    "            in_dim = h_dim\n",
    "        \n",
    "        # Output layer\n",
    "        layers.append(nn.Linear(in_dim, 1))\n",
    "        self.net = nn.Sequential(*layers)\n",
    "        \n",
    "        # Weight initialization\n",
    "        self._init_weights()\n",
    "    \n",
    "    def _init_weights(self):\n",
    "        \"\"\"Kaimingåˆå§‹åŒ–\"\"\"\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n",
    "                if m.bias is not None:\n",
    "                    nn.init.constant_(m.bias, 0)\n",
    "            elif isinstance(m, nn.BatchNorm1d):\n",
    "                nn.init.constant_(m.weight, 1)\n",
    "                nn.init.constant_(m.bias, 0)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.net(x)\n",
    "        if self.use_residual:\n",
    "            out = out + x\n",
    "        return out\n",
    "\n",
    "\n",
    "class ImprovedAdditiveModel(nn.Module):\n",
    "    \"\"\"æ”¹é€²çš„ADNN - æ”¯æ´component-wiseæ­£å‰‡åŒ–\"\"\"\n",
    "    \n",
    "    def __init__(self, index_list: List[List[int]], hidden_dims: List[int],\n",
    "                 output_dim: int = 1, use_bn: bool = True, dropout: float = 0.1,\n",
    "                 activation: str = 'relu', l1_reg: float = 0.0):\n",
    "        super().__init__()\n",
    "        self.index_list = index_list\n",
    "        self.l1_reg = l1_reg\n",
    "        \n",
    "        # Build feature networks\n",
    "        self.feature_nets = nn.ModuleList([\n",
    "            ImprovedFeatureNet(len(indices), hidden_dims, use_bn, dropout, activation)\n",
    "            for indices in index_list\n",
    "        ])\n",
    "        \n",
    "        # Combiner with learnable weights\n",
    "        self.combiner = nn.Linear(len(index_list), output_dim, bias=True)\n",
    "        self.hook = {}\n",
    "    \n",
    "    def forward(self, X):\n",
    "        individual_outputs = []\n",
    "        \n",
    "        for indices, net in zip(self.index_list, self.feature_nets):\n",
    "            x_sub = X[:, indices]\n",
    "            out = net(x_sub)\n",
    "            individual_outputs.append(out)\n",
    "        \n",
    "        combined = torch.cat(individual_outputs, dim=1)\n",
    "        self.hook['acomp'] = combined  # For interpretability\n",
    "        return self.combiner(combined)\n",
    "    \n",
    "    def get_l1_loss(self):\n",
    "        \"\"\"è¨ˆç®—L1æ­£å‰‡åŒ–æå¤±\"\"\"\n",
    "        if self.l1_reg <= 0:\n",
    "            return 0.0\n",
    "        \n",
    "        l1_loss = 0.0\n",
    "        for param in self.parameters():\n",
    "            l1_loss += torch.abs(param).sum()\n",
    "        return self.l1_reg * l1_loss\n",
    "\n",
    "\n",
    "class ImprovedDNNBaseline(nn.Module):\n",
    "    \"\"\"æ”¹é€²çš„DNN baseline - åŠ å…¥ç¾ä»£åŒ–æŠ€å·§\"\"\"\n",
    "    \n",
    "    def __init__(self, input_dim: int, hidden_dims: List[int] = [128, 64, 32],\n",
    "                 use_bn: bool = True, dropout: float = 0.2, activation: str = 'gelu'):\n",
    "        super().__init__()\n",
    "        \n",
    "        layers = []\n",
    "        in_dim = input_dim\n",
    "        \n",
    "        for i, h_dim in enumerate(hidden_dims):\n",
    "            layers.append(nn.Linear(in_dim, h_dim))\n",
    "            \n",
    "            if use_bn:\n",
    "                layers.append(nn.BatchNorm1d(h_dim))\n",
    "            \n",
    "            if activation == 'relu':\n",
    "                layers.append(nn.ReLU())\n",
    "            elif activation == 'leaky_relu':\n",
    "                layers.append(nn.LeakyReLU(0.2))\n",
    "            elif activation == 'gelu':\n",
    "                layers.append(nn.GELU())\n",
    "            \n",
    "            if dropout > 0:\n",
    "                layers.append(nn.Dropout(dropout))\n",
    "            \n",
    "            in_dim = h_dim\n",
    "        \n",
    "        layers.append(nn.Linear(in_dim, 1))\n",
    "        self.net = nn.Sequential(*layers)\n",
    "        self._init_weights()\n",
    "    \n",
    "    def _init_weights(self):\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear):\n",
    "                nn.init.kaiming_normal_(m.weight)\n",
    "                if m.bias is not None:\n",
    "                    nn.init.constant_(m.bias, 0)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bc7e8991",
   "metadata": {},
   "source": [
    "## Training setting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "512058b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class TrainerConfig:\n",
    "    \"\"\"è¨“ç·´é…ç½® - é›†ä¸­ç®¡ç†æ‰€æœ‰è¶…åƒæ•¸\"\"\"\n",
    "    \n",
    "    # Basic training\n",
    "    n_epochs: int = 10000\n",
    "    batch_size: int = 64\n",
    "    learning_rate: float = 1e-3\n",
    "    weight_decay: float = 1e-5\n",
    "    \n",
    "    # Learning rate schedule\n",
    "    scheduler_type: str = 'cosine'  # 'cosine', 'plateau', 'warmup_cosine'\n",
    "    warmup_epochs: int = 100\n",
    "    min_lr: float = 1e-6\n",
    "    \n",
    "    # Early stopping\n",
    "    early_stopping_patience: int = 50\n",
    "    early_stopping_delta: float = 0.0\n",
    "    \n",
    "    # Optimization tricks\n",
    "    gradient_clip_value: float = 1.0\n",
    "    use_mixed_precision: bool = True\n",
    "    \n",
    "    # Checkpoint\n",
    "    save_checkpoint_every: int = 500\n",
    "    checkpoint_dir: str = './checkpoints'\n",
    "    save_best_only: bool = True\n",
    "    \n",
    "    # Logging\n",
    "    log_every: int = 100\n",
    "    use_tensorboard: bool = False\n",
    "    \n",
    "    # Regularization\n",
    "    label_smoothing: float = 0.0\n",
    "\n",
    "\n",
    "class WarmupCosineScheduler:\n",
    "    \"\"\"Warm-up + Cosine Annealingå­¸ç¿’çŽ‡èª¿åº¦å™¨\"\"\"\n",
    "    \n",
    "    def __init__(self, optimizer, warmup_epochs: int, total_epochs: int,\n",
    "                 base_lr: float, min_lr: float = 1e-6):\n",
    "        self.optimizer = optimizer\n",
    "        self.warmup_epochs = warmup_epochs\n",
    "        self.total_epochs = total_epochs\n",
    "        self.base_lr = base_lr\n",
    "        self.min_lr = min_lr\n",
    "        self.current_epoch = 0\n",
    "    \n",
    "    def step(self):\n",
    "        \"\"\"æ›´æ–°å­¸ç¿’çŽ‡\"\"\"\n",
    "        if self.current_epoch < self.warmup_epochs:\n",
    "            # Linear warmup\n",
    "            lr = self.base_lr * (self.current_epoch + 1) / self.warmup_epochs\n",
    "        else:\n",
    "            # Cosine annealing\n",
    "            progress = (self.current_epoch - self.warmup_epochs) / \\\n",
    "                      (self.total_epochs - self.warmup_epochs)\n",
    "            lr = self.min_lr + 0.5 * (self.base_lr - self.min_lr) * \\\n",
    "                (1 + np.cos(np.pi * progress))\n",
    "        \n",
    "        for param_group in self.optimizer.param_groups:\n",
    "            param_group['lr'] = lr\n",
    "        \n",
    "        self.current_epoch += 1\n",
    "        return lr\n",
    "\n",
    "\n",
    "class EarlyStopper:\n",
    "    \"\"\"æ”¹é€²çš„Early Stopping\"\"\"\n",
    "    \n",
    "    def __init__(self, patience: int = 50, min_delta: float = 0.0, mode: str = 'min'):\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.mode = mode\n",
    "        self.counter = 0\n",
    "        self.best_score = float('inf') if mode == 'min' else float('-inf')\n",
    "        self.early_stop_triggered = False\n",
    "    \n",
    "    def __call__(self, score: float) -> bool:\n",
    "        if self.mode == 'min':\n",
    "            improved = score < (self.best_score - self.min_delta)\n",
    "        else:\n",
    "            improved = score > (self.best_score + self.min_delta)\n",
    "        \n",
    "        if improved:\n",
    "            self.best_score = score\n",
    "            self.counter = 0\n",
    "            return False\n",
    "        else:\n",
    "            self.counter += 1\n",
    "            if self.counter >= self.patience:\n",
    "                self.early_stop_triggered = True\n",
    "                return True\n",
    "        return False\n",
    "\n",
    "\n",
    "class ModelTrainer:\n",
    "    \"\"\"çµ±ä¸€çš„è¨“ç·´å™¨ - æ”¯æ´æ‰€æœ‰ç¾ä»£åŒ–è¨“ç·´æŠ€å·§\"\"\"\n",
    "    \n",
    "    def __init__(self, model: nn.Module, config: TrainerConfig):\n",
    "        self.model = model.to(device)\n",
    "        self.config = config\n",
    "        \n",
    "        # Optimizer\n",
    "        self.optimizer = optim.AdamW(\n",
    "            model.parameters(),\n",
    "            lr=config.learning_rate,\n",
    "            weight_decay=config.weight_decay\n",
    "        )\n",
    "        \n",
    "        # Loss function\n",
    "        self.criterion = nn.MSELoss()\n",
    "        \n",
    "        # Learning rate scheduler\n",
    "        self.scheduler = self._create_scheduler()\n",
    "        \n",
    "        # Early stopping\n",
    "        self.early_stopper = EarlyStopper(\n",
    "            patience=config.early_stopping_patience,\n",
    "            min_delta=config.early_stopping_delta\n",
    "        )\n",
    "        \n",
    "        # Mixed precision\n",
    "        self.scaler = GradScaler() if config.use_mixed_precision else None\n",
    "        \n",
    "        # Checkpoint directory\n",
    "        self.checkpoint_dir = Path(config.checkpoint_dir)\n",
    "        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)\n",
    "        \n",
    "        # Training history\n",
    "        self.history = {\n",
    "            'train_loss': [],\n",
    "            'val_loss': [],\n",
    "            'lr': []\n",
    "        }\n",
    "        \n",
    "        self.best_val_loss = float('inf')\n",
    "        self.best_model_state = None\n",
    "    \n",
    "    def _create_scheduler(self):\n",
    "        \"\"\"å‰µå»ºå­¸ç¿’çŽ‡èª¿åº¦å™¨\"\"\"\n",
    "        if self.config.scheduler_type == 'cosine':\n",
    "            return CosineAnnealingLR(\n",
    "                self.optimizer,\n",
    "                T_max=self.config.n_epochs,\n",
    "                eta_min=self.config.min_lr\n",
    "            )\n",
    "        elif self.config.scheduler_type == 'plateau':\n",
    "            return ReduceLROnPlateau(\n",
    "                self.optimizer,\n",
    "                mode='min',\n",
    "                factor=0.5,\n",
    "                patience=20,\n",
    "                min_lr=self.config.min_lr\n",
    "            )\n",
    "        elif self.config.scheduler_type == 'warmup_cosine':\n",
    "            return WarmupCosineScheduler(\n",
    "                self.optimizer,\n",
    "                warmup_epochs=self.config.warmup_epochs,\n",
    "                total_epochs=self.config.n_epochs,\n",
    "                base_lr=self.config.learning_rate,\n",
    "                min_lr=self.config.min_lr\n",
    "            )\n",
    "        else:\n",
    "            return None\n",
    "    \n",
    "    def train_epoch(self, train_loader) -> float:\n",
    "        \"\"\"è¨“ç·´ä¸€å€‹epoch\"\"\"\n",
    "        self.model.train()\n",
    "        total_loss = 0.0\n",
    "        num_batches = 0\n",
    "        \n",
    "        for X_batch, y_batch in train_loader:\n",
    "            X_batch = X_batch.to(device)\n",
    "            y_batch = y_batch.to(device)\n",
    "            \n",
    "            self.optimizer.zero_grad()\n",
    "            \n",
    "            # Mixed precision training\n",
    "            if self.scaler is not None:\n",
    "                with autocast():\n",
    "                    y_pred = self.model(X_batch)\n",
    "                    loss = self.criterion(y_pred, y_batch)\n",
    "                    \n",
    "                    # Add L1 regularization if applicable\n",
    "                    if hasattr(self.model, 'get_l1_loss'):\n",
    "                        loss += self.model.get_l1_loss()\n",
    "                \n",
    "                self.scaler.scale(loss).backward()\n",
    "                \n",
    "                # Gradient clipping\n",
    "                if self.config.gradient_clip_value > 0:\n",
    "                    self.scaler.unscale_(self.optimizer)\n",
    "                    torch.nn.utils.clip_grad_norm_(\n",
    "                        self.model.parameters(),\n",
    "                        self.config.gradient_clip_value\n",
    "                    )\n",
    "                \n",
    "                self.scaler.step(self.optimizer)\n",
    "                self.scaler.update()\n",
    "            else:\n",
    "                y_pred = self.model(X_batch)\n",
    "                loss = self.criterion(y_pred, y_batch)\n",
    "                \n",
    "                if hasattr(self.model, 'get_l1_loss'):\n",
    "                    loss += self.model.get_l1_loss()\n",
    "                \n",
    "                loss.backward()\n",
    "                \n",
    "                if self.config.gradient_clip_value > 0:\n",
    "                    torch.nn.utils.clip_grad_norm_(\n",
    "                        self.model.parameters(),\n",
    "                        self.config.gradient_clip_value\n",
    "                    )\n",
    "                \n",
    "                self.optimizer.step()\n",
    "            \n",
    "            total_loss += loss.item()\n",
    "            num_batches += 1\n",
    "        \n",
    "        return total_loss / num_batches\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def validate(self, val_loader) -> float:\n",
    "        \"\"\"é©—è­‰\"\"\"\n",
    "        self.model.eval()\n",
    "        total_loss = 0.0\n",
    "        num_batches = 0\n",
    "        \n",
    "        for X_batch, y_batch in val_loader:\n",
    "            X_batch = X_batch.to(device)\n",
    "            y_batch = y_batch.to(device)\n",
    "            \n",
    "            y_pred = self.model(X_batch)\n",
    "            loss = self.criterion(y_pred, y_batch)\n",
    "            \n",
    "            total_loss += loss.item()\n",
    "            num_batches += 1\n",
    "        \n",
    "        return total_loss / num_batches\n",
    "    \n",
    "    def save_checkpoint(self, epoch: int, val_loss: float, \n",
    "                       filename: Optional[str] = None):\n",
    "        \"\"\"ä¿å­˜checkpoint\"\"\"\n",
    "        if filename is None:\n",
    "            filename = f'checkpoint_epoch_{epoch}.pt'\n",
    "        \n",
    "        checkpoint = {\n",
    "            'epoch': epoch,\n",
    "            'model_state_dict': self.model.state_dict(),\n",
    "            'optimizer_state_dict': self.optimizer.state_dict(),\n",
    "            'val_loss': val_loss,\n",
    "            'train_loss': self.history['train_loss'][-1] if self.history['train_loss'] else None,\n",
    "            'config': self.config,\n",
    "            'history': self.history\n",
    "        }\n",
    "        \n",
    "        if self.scaler is not None:\n",
    "            checkpoint['scaler_state_dict'] = self.scaler.state_dict()\n",
    "        \n",
    "        filepath = self.checkpoint_dir / filename\n",
    "        torch.save(checkpoint, filepath)\n",
    "        return filepath\n",
    "    \n",
    "    def load_checkpoint(self, filepath: str):\n",
    "        \"\"\"è¼‰å…¥checkpoint\"\"\"\n",
    "        checkpoint = torch.load(filepath, map_location=device)\n",
    "        self.model.load_state_dict(checkpoint['model_state_dict'])\n",
    "        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "        \n",
    "        if self.scaler is not None and 'scaler_state_dict' in checkpoint:\n",
    "            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])\n",
    "        \n",
    "        self.history = checkpoint.get('history', self.history)\n",
    "        return checkpoint['epoch']\n",
    "    \n",
    "    def fit(self, X_train: torch.Tensor, y_train: torch.Tensor,\n",
    "            X_val: torch.Tensor, y_val: torch.Tensor):\n",
    "        \"\"\"å®Œæ•´è¨“ç·´æµç¨‹\"\"\"\n",
    "        \n",
    "        # Create data loaders\n",
    "        train_dataset = torch.utils.data.TensorDataset(X_train, y_train)\n",
    "        train_loader = torch.utils.data.DataLoader(\n",
    "            train_dataset,\n",
    "            batch_size=self.config.batch_size,\n",
    "            shuffle=True,\n",
    "            num_workers=0\n",
    "        )\n",
    "        \n",
    "        val_dataset = torch.utils.data.TensorDataset(X_val, y_val)\n",
    "        val_loader = torch.utils.data.DataLoader(\n",
    "            val_dataset,\n",
    "            batch_size=self.config.batch_size * 2,\n",
    "            shuffle=False,\n",
    "            num_workers=0\n",
    "        )\n",
    "        \n",
    "        # Training loop with progress bar\n",
    "        pbar = tqdm(range(self.config.n_epochs), desc='Training')\n",
    "        start_time = time.time()\n",
    "        \n",
    "        for epoch in pbar:\n",
    "            # Train\n",
    "            train_loss = self.train_epoch(train_loader)\n",
    "            \n",
    "            # Validate\n",
    "            val_loss = self.validate(val_loader)\n",
    "            \n",
    "            # Update scheduler\n",
    "            current_lr = self.optimizer.param_groups[0]['lr']\n",
    "            if isinstance(self.scheduler, ReduceLROnPlateau):\n",
    "                self.scheduler.step(val_loss)\n",
    "            elif self.scheduler is not None:\n",
    "                self.scheduler.step()\n",
    "            \n",
    "            # Record history\n",
    "            self.history['train_loss'].append(train_loss)\n",
    "            self.history['val_loss'].append(val_loss)\n",
    "            self.history['lr'].append(current_lr)\n",
    "            \n",
    "            # Update progress bar\n",
    "            pbar.set_postfix({\n",
    "                'train_loss': f'{train_loss:.6f}',\n",
    "                'val_loss': f'{val_loss:.6f}',\n",
    "                'lr': f'{current_lr:.6f}'\n",
    "            })\n",
    "            \n",
    "            # Save best model\n",
    "            if val_loss < self.best_val_loss:\n",
    "                self.best_val_loss = val_loss\n",
    "                self.best_model_state = copy.deepcopy(self.model.state_dict())\n",
    "                \n",
    "                if self.config.save_best_only:\n",
    "                    self.save_checkpoint(epoch, val_loss, 'best_model.pt')\n",
    "            \n",
    "            # Periodic checkpoint\n",
    "            if (epoch + 1) % self.config.save_checkpoint_every == 0:\n",
    "                self.save_checkpoint(epoch, val_loss)\n",
    "            \n",
    "            # Early stopping\n",
    "            if self.early_stopper(val_loss):\n",
    "                print(f\"\\nEpoch {epoch+1}\")\n",
    "                break\n",
    "        \n",
    "        # Load best model\n",
    "        if self.best_model_state is not None:\n",
    "            self.model.load_state_dict(self.best_model_state)\n",
    "        \n",
    "        total_time = time.time() - start_time\n",
    "        \n",
    "        print(f\"Total Running Time“: {total_time:.2f}\")\n",
    "        print(f\"Validation Loss: {self.best_val_loss:.6f}\")\n",
    "        \n",
    "        return {\n",
    "            'best_val_loss': self.best_val_loss,\n",
    "            'total_time': total_time,\n",
    "            'history': self.history\n",
    "        }\n",
    "    \n",
    "def train_model(model, X_train, y_train, X_val, y_val, \n",
    "               config: Optional[TrainerConfig] = None, **kwargs):\n",
    "    \"\"\"\n",
    "    ä¾¿æ·çš„è¨“ç·´å‡½æ•¸\n",
    "    \n",
    "    Args:\n",
    "        model: PyTorchæ¨¡åž‹\n",
    "        X_train, y_train: è¨“ç·´æ•¸æ“š\n",
    "        X_val, y_val: é©—è­‰æ•¸æ“š\n",
    "        config: TrainerConfigå°è±¡ (å¯é¸)\n",
    "        **kwargs: è¦†è“‹configçš„åƒæ•¸\n",
    "    \"\"\"\n",
    "    if config is None:\n",
    "        config = TrainerConfig(**kwargs)\n",
    "    else:\n",
    "        # Override config with kwargs\n",
    "        for k, v in kwargs.items():\n",
    "            if hasattr(config, k):\n",
    "                setattr(config, k, v)\n",
    "    \n",
    "    trainer = ModelTrainer(model, config)\n",
    "    results = trainer.fit(X_train, y_train, X_val, y_val)\n",
    "    \n",
    "    return trainer, results\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99cf947a",
   "metadata": {},
   "source": [
    "## Helpful Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "481bccb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def report(alpha, result, truey):\n",
    "\n",
    "    summary = {}\n",
    "    act_idx_list = []\n",
    "    rmse_list = torch.zeros((len(alpha)))\n",
    "    for i in range(len(alpha)):\n",
    "        active_idx = torch.where( torch.norm(result['component'][i], dim = 0) != 0)[0]\n",
    "        act_idx_list.append(torch.where( torch.norm(result['component'][i], dim = 0) != 0)[0])\n",
    "        rmse_list[i] = np.sqrt(mean_squared_error(result['component'][i].sum(dim = 1), truey))\n",
    "\n",
    "        print(f\"Alpha: {alpha[i]:3f} | # Active index: {active_idx.size()[0]} | RMSE: {rmse_list[i]:3f} |\")\n",
    "\n",
    "    summary['alpha'] = alpha\n",
    "    summary['act_idx'] = act_idx_list\n",
    "    summary['rmse'] = rmse_list\n",
    "\n",
    "    return summary\n",
    "\n",
    "def extract_active_features(X: torch.tensor, active_idx: list[int]) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Extract active features from X based on active indices.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray\n",
    "        Data matrix of shape (n_samples, n_features)\n",
    "    active_idx : list[int]\n",
    "        Indices of active features\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        DataFrame with columns named x1, x2, ..., for active features\n",
    "    \"\"\"\n",
    "    data = pd.DataFrame({f\"x{idx+1}\": X[:, idx] for idx in active_idx})\n",
    "    interaction = list(combinations(active_idx, 2))\n",
    "    return data, interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70f3d0bc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb9c9e5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f8af6c87",
   "metadata": {},
   "source": [
    "## Loading Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "166678af",
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import join as pjoin, exists as pexists\n",
    "\n",
    "import os\n",
    "import glob\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "class RealDatasetLoader:\n",
    "    \"\"\"Utility loader for several real datasets used in experiments.\n",
    "\n",
    "    Methods return a dict containing at minimum:\n",
    "      - X_train: pandas.DataFrame\n",
    "      - y_train: numpy.ndarray\n",
    "      - X_test: pandas.DataFrame\n",
    "      - y_test: numpy.ndarray\n",
    "      - feature_names: list\n",
    "      - problem: 'regression' (for these datasets)\n",
    "\n",
    "    All methods accept a `normalize` flag which, if True, will z-score\n",
    "    the features using the training set statistics. For compatibility with\n",
    "    existing notebooks this loader returns DataFrames for X (so callers\n",
    "    can later call `.to_numpy()` before converting to torch tensors).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, base_path=None):\n",
    "        if base_path is None:\n",
    "            # default to repository Real-dataset folder (one level up from src)\n",
    "            base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'Real-dataset'))\n",
    "        self.base_path = base_path\n",
    "\n",
    "    def _normalize(self, X_train, X_test):\n",
    "        # X_train/test are pandas DataFrames\n",
    "        x_max = X_train.max(axis=0)\n",
    "        x_min = X_train.min(axis=0)\n",
    "        X_train_n = (X_train - x_min) / (x_max - x_min)\n",
    "        X_test_n = (X_test - x_min) /  (x_max - x_min)\n",
    "        return X_train_n, X_test_n\n",
    "\n",
    "    def _standardize(self, X_train, X_test):\n",
    "        # X_train/test are pandas DataFrames\n",
    "        mu = X_train.mean(axis=0)\n",
    "        std = X_train.std(axis=0).replace(0, 1.0)\n",
    "        X_train_n = (X_train - mu) / std\n",
    "        X_test_n = (X_test - mu) / std\n",
    "        return X_train_n, X_test_n\n",
    "    \n",
    "    def load_wine(self, fold=0, label_col='quality', infile=None, normalize=False, standardize = False):\n",
    "        \"\"\"Load winequality-white dataset and return train/test split by fold files.\n",
    "\n",
    "        Expects files under `<base_path>/wine/`:\n",
    "          - `winequality-white.csv` (semicolon-separated)\n",
    "          - `train{fold}.txt` and `test{fold}.txt` with row indices\n",
    "        \"\"\"\n",
    "        folder = os.path.join(self.base_path, 'wine')\n",
    "        if infile is None:\n",
    "            infile = os.path.join(folder, 'winequality-white.csv')\n",
    "        if not os.path.exists(infile):\n",
    "            raise FileNotFoundError(f\"Wine file not found at {infile}\")\n",
    "\n",
    "        # winequality files use ';' as separator\n",
    "        df = pd.read_csv(infile, sep=';')\n",
    "        if label_col not in df.columns:\n",
    "            raise KeyError(f\"Label column '{label_col}' not found in {infile}\")\n",
    "\n",
    "        train_idx_file = os.path.join(folder, f'train{fold}.txt')\n",
    "        test_idx_file = os.path.join(folder, f'test{fold}.txt')\n",
    "        if not os.path.exists(train_idx_file) or not os.path.exists(test_idx_file):\n",
    "            raise FileNotFoundError('Train/test split files not found under %s' % folder)\n",
    "\n",
    "        train_idx = pd.read_csv(train_idx_file, header=None)[0].values\n",
    "        test_idx = pd.read_csv(test_idx_file, header=None)[0].values\n",
    "\n",
    "        feature_cols = [c for c in df.columns if c != label_col]\n",
    "        X = df[feature_cols]\n",
    "        y = df[label_col].values.astype(np.float32)\n",
    "\n",
    "        X_train = X.iloc[train_idx].reset_index(drop=True)\n",
    "        X_test = X.iloc[test_idx].reset_index(drop=True)\n",
    "        y_train = y[train_idx]\n",
    "        y_test = y[test_idx]\n",
    "\n",
    "        if normalize:\n",
    "            X_train, X_test = self._normalize(X_train, X_test)\n",
    "        elif standardize:\n",
    "            X_train, X_test = self._standardize(X_train, X_test)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        return dict(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,\n",
    "                    feature_names=list(feature_cols), problem='regression')\n",
    "\n",
    "    def load_bikeshare(self, fold=0, hour=True, infile=None, normalize=False, standardize = False):\n",
    "        \"\"\"Load bikeshare dataset. If `hour` True uses `hour.csv`, otherwise `day.csv`.\n",
    "\n",
    "        Expects split indices under `<base_path>/bikeshare/train{fold}.txt` etc.\n",
    "        \"\"\"\n",
    "        folder = os.path.join(self.base_path, 'bikeshare')\n",
    "        fname = 'hour.csv' if hour else 'day.csv'\n",
    "        if infile is None:\n",
    "            infile = os.path.join(folder, fname)\n",
    "        if not os.path.exists(infile):\n",
    "            raise FileNotFoundError(f\"Bikeshare file not found at {infile}\")\n",
    "\n",
    "        df = pd.read_csv(infile).set_index('instant')\n",
    "        train_cols = ['season', 'yr', 'mnth', 'hr', 'holiday', 'weekday',\n",
    "                      'workingday', 'weathersit', 'temp', 'atemp', 'hum', 'windspeed']\n",
    "        label = 'cnt'\n",
    "        for c in train_cols + [label]:\n",
    "            if c not in df.columns:\n",
    "                raise KeyError(f\"Expected column '{c}' not found in bikeshare data\")\n",
    "\n",
    "        X = df[train_cols]\n",
    "        y = df[label].values.astype(np.float32)\n",
    "\n",
    "        train_idx_file = os.path.join(folder, f'train{fold}.txt')\n",
    "        test_idx_file = os.path.join(folder, f'test{fold}.txt')\n",
    "        if not os.path.exists(train_idx_file) or not os.path.exists(test_idx_file):\n",
    "            raise FileNotFoundError('Train/test split files not found under %s' % folder)\n",
    "\n",
    "        train_idx = pd.read_csv(train_idx_file, header=None)[0].values\n",
    "        test_idx = pd.read_csv(test_idx_file, header=None)[0].values\n",
    "\n",
    "        X_train = X.iloc[train_idx].reset_index(drop=True)\n",
    "        X_test = X.iloc[test_idx].reset_index(drop=True)\n",
    "        y_train = y[train_idx]\n",
    "        y_test = y[test_idx]\n",
    "\n",
    "        if normalize:\n",
    "            X_train, X_test = self._normalize(X_train, X_test)\n",
    "        elif standardize:\n",
    "            X_train, X_test = self._standardize(X_train, X_test)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        return dict(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,\n",
    "                    feature_names=train_cols, problem='regression')\n",
    "\n",
    "    def load_california_housing(self, test_size=0.2, random_state=0, normalize=False, standardize = False):\n",
    "        \"\"\"Fetch California housing from sklearn and split into train/test.\"\"\"\n",
    "        try:\n",
    "            from sklearn.datasets import fetch_california_housing\n",
    "        except Exception as e:\n",
    "            raise RuntimeError('scikit-learn is required to load California housing') from e\n",
    "\n",
    "        housing = fetch_california_housing()\n",
    "        X = pd.DataFrame(housing.data, columns=housing.feature_names)\n",
    "        y = housing.target.astype(np.float32)\n",
    "\n",
    "        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size,\n",
    "                                                            random_state=random_state)\n",
    "        if normalize:\n",
    "            X_train, X_test = self._normalize(X_train, X_test)\n",
    "        elif standardize:\n",
    "            X_train, X_test = self._standardize(X_train, X_test)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        return dict(X_train=X_train.reset_index(drop=True), y_train=y_train,\n",
    "                    X_test=X_test.reset_index(drop=True), y_test=y_test,\n",
    "                    feature_names=housing.feature_names, problem='regression')\n",
    "\n",
    "    def load_fico(self, file_path=None, label_col=None, test_size=0.2, random_state=0, normalize=False, standardize = False):\n",
    "        \"\"\"Load a FICO dataset CSV/Excel. If `file_path` is None we search the base_path\n",
    "        for files containing 'fico' in their filename. The caller must provide a `label_col`\n",
    "        if the dataset doesn't use a standard column name.\n",
    "        \"\"\"\n",
    "\n",
    "        file_path = os.path.join(self.base_path, 'fico.csv')\n",
    "        # try CSV then Excel\n",
    "        df = pd.read_csv(file_path, index_col=0)\n",
    "    \n",
    "\n",
    "        # infer label column if missing\n",
    "        if label_col is None:\n",
    "            if 'target' in df.columns:\n",
    "                label_col = 'target'\n",
    "            else:\n",
    "                # fallback to last column\n",
    "                label_col = df.columns[-1]\n",
    "\n",
    "        if label_col not in df.columns:\n",
    "            raise KeyError(f\"Label column '{label_col}' not found in {file_path}\")\n",
    "\n",
    "        feature_cols = [c for c in df.columns if c != label_col]\n",
    "        X = df[feature_cols]\n",
    "        y = df[label_col].values.astype(np.float32)\n",
    "\n",
    "        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size,\n",
    "                                                            random_state=random_state)\n",
    "        if normalize:\n",
    "            X_train, X_test = self._normalize(X_train, X_test)\n",
    "        elif standardize:\n",
    "            X_train, X_test = self._standardize(X_train, X_test)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        return dict(X_train=X_train.reset_index(drop=True), y_train=y_train,\n",
    "                    X_test=X_test.reset_index(drop=True), y_test=y_test,\n",
    "                    feature_names=feature_cols, problem='regression')\n",
    "    \n",
    "path = '/home/users/yhung7/SDAM/Additional_exp'\n",
    "Dataset = RealDatasetLoader(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c09f4592",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_train size: torch.Size([7896, 25]), y_train size: torch.Size([7896])\n",
      "X_test size: torch.Size([1975, 25]), y_test size: torch.Size([1975])\n"
     ]
    }
   ],
   "source": [
    "case = 'fico'\n",
    "\n",
    "if case == 'wine':\n",
    "    dataset_nrm = Dataset.load_wine(normalize=True)\n",
    "    dataset_std = Dataset.load_wine(standardize=True)\n",
    "elif case == 'bike':\n",
    "    dataset_nrm = Dataset.load_wine(normalize=True)\n",
    "    dataset_std = Dataset.load_wine(standardize=True)\n",
    "elif case == 'ca':\n",
    "    dataset_nrm = Dataset.load_california_housing(normalize=True)\n",
    "    dataset_std = Dataset.load_wine(standardize=True)\n",
    "elif case == 'fico':\n",
    "    dataset_nrm = Dataset.load_fico(normalize=True)\n",
    "    dataset_std = Dataset.load_fico(standardize=True)\n",
    "else:\n",
    "    pass\n",
    "\n",
    "X_train = torch.from_numpy(dataset_nrm['X_train'].to_numpy())\n",
    "X_test = torch.from_numpy(dataset_nrm['X_test'].to_numpy())\n",
    "\n",
    "## Normalization\n",
    "y_train = torch.tensor(dataset_nrm['y_train']).to(torch.float32)\n",
    "y_test = torch.tensor(dataset_nrm['y_test']).to(torch.float32)\n",
    "\n",
    "\n",
    "print(f\"X_train size: {X_train.shape}, y_train size: {y_train.shape}\")\n",
    "print(f\"X_test size: {X_test.shape}, y_test size: {y_test.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdac3002-7a9c-467d-89c3-391d6d57a081",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dd2d8a2a-1182-469d-b197-10771fa86d8f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device:  cuda\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "\n",
    "dataset = torch.load('/home/users/yhung7/SDAM/data/only_main_data.pt', weights_only = True)\n",
    "    \n",
    "length, sample, feature = dataset['X_train'].size()\n",
    "mse_list = np.zeros(length)\n",
    "random_state = 0\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "print('device: ', device)\n",
    "\n",
    "length = dataset['X_train'].size()[0]\n",
    "length = 1\n",
    "start_time = time.time()\n",
    "for i in range(length):\n",
    "    \n",
    "    X_train = dataset['X_train'][i]\n",
    "    y_train = dataset['y_train'][i]\n",
    "    X_test = dataset['X_test'][i]\n",
    "    y_test = dataset['y_test'][i]\n",
    "    X_val = dataset['X_valid'][i]\n",
    "    y_val = dataset['y_valid'][i]\n",
    "    scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "\n",
    "    X_train = scaler.fit_transform(X_train.cpu().numpy())\n",
    "    \n",
    "    # 4. 轉回 PyTorch Tensor\n",
    "    X_train = torch.tensor(X_train, dtype=torch.float32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aceefce-d40c-4578-b957-a8e4a7a06fad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eb3e6fe-b307-405b-b2a6-99016fb5d236",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1541a7bb",
   "metadata": {},
   "source": [
    "## Main function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "33278a2e-5132-42d3-898c-532694ce7397",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "\n",
    "\n",
    "case = 'bike'\n",
    "Train = pd.read_csv('./'+case+'_Train.csv')\n",
    "Valid = pd.read_csv('./'+case+'_Valid.csv')\n",
    "Test = pd.read_csv('./'+case+'_Test.csv')\n",
    "\n",
    "y_train = Train['y']\n",
    "X_train = Train.iloc[:, :-1]\n",
    "y_val = Valid['y']\n",
    "X_val = Valid.iloc[:, :-1]\n",
    "y_test = Test['y']\n",
    "X_test = Test.iloc[:, :-1]\n",
    "\n",
    "X_train = torch.tensor(X_train.to_numpy(), dtype=torch.float32)\n",
    "y_train = torch.tensor(y_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val.to_numpy(), dtype=torch.float32)\n",
    "y_val = torch.tensor(y_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test.to_numpy(), dtype=torch.float32)\n",
    "y_test = torch.tensor(y_test, dtype=torch.float32)\n",
    "\n",
    "scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "\n",
    "X_train = scaler.fit_transform(X_train.cpu().numpy())\n",
    "\n",
    "# 4. 轉回 PyTorch Tensor\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b50025f-75a0-460d-a82c-0874ab2fcb1b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ba6a1d4f",
   "metadata": {},
   "source": [
    "### SCREENING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "13580503",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Alpha: 0.00 | Need more iteration.\n",
      "Alpha: 0.00 | Need more iteration.\n",
      "Alpha: 0.01 | Need more iteration.\n",
      "Alpha: 0.05 | Need more iteration.\n",
      "Alpha: 0.50 | Need more iteration.\n",
      "Alpha: 3.50 | Need more iteration.\n",
      "Alpha: 4.50 | Need more iteration.\n",
      "Alpha: 100.00 | Need more iteration.\n",
      "Alpha: 200.00 | Convergence.\n",
      "Alpha: 1000.00 | Convergence.\n"
     ]
    }
   ],
   "source": [
    "if case == 'wine':\n",
    "    alpha_list = torch.tensor(list(np.arange(0.1, 10.1, 1.5))+ [25, 50, 100])\n",
    "elif case == 'ca':\n",
    "    alpha_list = torch.tensor(list(np.arange(0.01, 2.5, 0.3))+ [5])\n",
    "elif case == 'bike':\n",
    "    alpha_list = torch.tensor([0.001, 0.005, 0.01, 0.05, 0.5, 3.5, 4.5, 100, 150, 200])\n",
    "elif case == 'fico':\n",
    "    alpha_list = torch.tensor(list(np.arange(0.005, 0.056, 0.01))+ list(np.arange(0.055, 0.51, 0.1)))\n",
    "else:\n",
    "    pass\n",
    "results_ = train_SAM(X_train[:800,:], y_train[:800], alpha_list, max_iter = 50, nk = 8, nb = 3, custom = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "d9630097",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Alpha: 0.001000 | # Active index: 12 | RMSE: 101.234840 |\n",
      "Alpha: 0.005000 | # Active index: 9 | RMSE: 103.758987 |\n",
      "Alpha: 0.010000 | # Active index: 9 | RMSE: 103.758667 |\n",
      "Alpha: 0.050000 | # Active index: 6 | RMSE: 111.208359 |\n",
      "Alpha: 0.500000 | # Active index: 6 | RMSE: 111.227547 |\n",
      "Alpha: 3.500000 | # Active index: 4 | RMSE: 112.146469 |\n",
      "Alpha: 4.500000 | # Active index: 4 | RMSE: 112.266945 |\n",
      "Alpha: 100.000000 | # Active index: 3 | RMSE: 192.103516 |\n",
      "Alpha: 200.000000 | # Active index: 0 | RMSE: 255.871506 |\n",
      "Alpha: 1000.000000 | # Active index: 0 | RMSE: 255.871506 |\n"
     ]
    }
   ],
   "source": [
    "summary = report(alpha_list, results_, y_train[:800])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417b3add",
   "metadata": {},
   "source": [
    "## Higher-order detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "339c918f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau\n",
    "from torch.cuda.amp import GradScaler, autocast\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import SplineTransformer, StandardScaler\n",
    "from sklearn.model_selection import KFold, cross_validate\n",
    "from sklearn.linear_model import LassoCV, Lasso, Ridge\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "from itertools import combinations\n",
    "from typing import List, Tuple, Dict, Optional\n",
    "from dataclasses import dataclass\n",
    "from tqdm.auto import tqdm\n",
    "from pathlib import Path\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# ============================================================\n",
    "# Section 1: Optimized Spline Basis Construction\n",
    "# ============================================================\n",
    "\n",
    "class OptimizedSplineBasis:\n",
    "    \"\"\"\n",
    "    Optimized spline basis construction with caching and vectorization.\n",
    "    Significantly faster than original implementation.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, n_splines: int = 10, spline_degree: int = 3, \n",
    "                 include_intercept: bool = False, cache_size: int = 100):\n",
    "        self.n_splines = n_splines\n",
    "        self.spline_degree = spline_degree\n",
    "        self.include_intercept = include_intercept\n",
    "        \n",
    "        # Cache for spline transformers to avoid repeated fitting\n",
    "        self.cache = {}\n",
    "        self.cache_size = cache_size\n",
    "    \n",
    "    def _get_cache_key(self, shape_str: str) -> str:\n",
    "        \"\"\"Generate cache key for basis function.\"\"\"\n",
    "        return f\"spline_{self.n_splines}_{self.spline_degree}_{shape_str}\"\n",
    "    \n",
    "    def build_univariate_basis(self, x: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Build spline basis for one variable with caching.\n",
    "        \n",
    "        Args:\n",
    "            x: 1D array of shape (n_samples,)\n",
    "            \n",
    "        Returns:\n",
    "            Spline basis matrix of shape (n_samples, n_basis_functions)\n",
    "        \"\"\"\n",
    "        x = np.asarray(x).reshape(-1, 1)\n",
    "        \n",
    "        # Try cache first\n",
    "        cache_key = self._get_cache_key(f\"{x.shape[0]}_{x.min():.2f}_{x.max():.2f}\")\n",
    "        \n",
    "        sp = SplineTransformer(\n",
    "            degree=self.spline_degree,\n",
    "            n_knots=self.n_splines,\n",
    "            include_bias=self.include_intercept\n",
    "        )\n",
    "        \n",
    "        basis = sp.fit_transform(x)\n",
    "        self.cache[cache_key] = basis\n",
    "        \n",
    "        return basis\n",
    "    \n",
    "    def build_bivariate_basis(self, x1: np.ndarray, x2: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Build tensor product spline basis for interactions (optimized).\n",
    "        \n",
    "        Uses efficient tensor product instead of einsum for large arrays.\n",
    "        \n",
    "        Args:\n",
    "            x1, x2: 1D arrays of shape (n_samples,)\n",
    "            \n",
    "        Returns:\n",
    "            Interaction basis matrix of shape (n_samples, n_basis1 * n_basis2)\n",
    "        \"\"\"\n",
    "        B1 = self.build_univariate_basis(x1)\n",
    "        B2 = self.build_univariate_basis(x2)\n",
    "        \n",
    "        # Efficient tensor product: (n, p1) x (n, p2) -> (n, p1*p2)\n",
    "        n_samples = B1.shape[0]\n",
    "        n_basis1, n_basis2 = B1.shape[1], B2.shape[1]\n",
    "        \n",
    "        # Method 1: Vectorized outer product (memory efficient)\n",
    "        result = np.zeros((n_samples, n_basis1 * n_basis2))\n",
    "        for i in range(n_samples):\n",
    "            result[i] = np.outer(B1[i], B2[i]).ravel()\n",
    "        \n",
    "        return result\n",
    "    \n",
    "    def build_trivariate_basis(self, x1: np.ndarray, x2: np.ndarray, \n",
    "                               x3: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Build 3-way interaction basis (higher-order).\n",
    "        \n",
    "        Args:\n",
    "            x1, x2, x3: 1D arrays of shape (n_samples,)\n",
    "            \n",
    "        Returns:\n",
    "            3-way interaction basis\n",
    "        \"\"\"\n",
    "        B1 = self.build_univariate_basis(x1)\n",
    "        B2 = self.build_univariate_basis(x2)\n",
    "        B3 = self.build_univariate_basis(x3)\n",
    "        \n",
    "        n_samples = B1.shape[0]\n",
    "        n_basis1, n_basis2, n_basis3 = B1.shape[1], B2.shape[1], B3.shape[1]\n",
    "        \n",
    "        result = np.zeros((n_samples, n_basis1 * n_basis2 * n_basis3))\n",
    "        for i in range(n_samples):\n",
    "            tensor = np.outer(np.outer(B1[i], B2[i]).ravel(), B3[i]).ravel()\n",
    "            result[i] = tensor\n",
    "        \n",
    "        return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "c8e3ec73",
   "metadata": {},
   "outputs": [],
   "source": [
    "from skglm import GroupLasso as SkglmGroupLasso\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "@dataclass\n",
    "class InteractionConfig:\n",
    "    \"\"\"Configuration for interaction detection\"\"\"\n",
    "    n_splines: int = 10\n",
    "    spline_degree: int = 3\n",
    "    include_intercept: bool = False\n",
    "    cv_folds: int = 5\n",
    "    max_interaction_order: int = 2\n",
    "    lambda_range: np.ndarray = None\n",
    "    group_reg_range: np.ndarray = None\n",
    "    group_reg: float = 0.1  # For GroupLasso\n",
    "    l1_reg: float = 0.0     # For GroupLasso L1 penalty\n",
    "    random_state: int = 42\n",
    "    verbose: bool = True\n",
    "    model_verbose: bool = False\n",
    "    def __post_init__(self):\n",
    "        if self.lambda_range is None:\n",
    "            self.lambda_range = np.logspace(-4, 1, 15)\n",
    "\n",
    "        if self.group_reg_range is None:\n",
    "            self.group_reg_range = np.logspace(-1, 1, 10)\n",
    "            ############### If case is bike ##############\n",
    "            #self.group_reg_range = np.logspace(0.6, 1, 5)\n",
    "            ##############################################\n",
    "class OptimizedAdditiveInteractionSelector:\n",
    "    \"\"\"\n",
    "    COMPLETE interaction selector supporting ALL 5 methods:\n",
    "    1. group_lasso - sklearn GroupLasso (requires: pip install group-lasso)\n",
    "    2. group_lasso_torch - PyTorch GPU-accelerated GroupLasso â­\n",
    "    3. elastic_net - ElasticNet (default, no external dependencies)\n",
    "    4. lasso - Lasso\n",
    "    5. ridge - Ridge\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, config: Optional[InteractionConfig] = None):\n",
    "        self.config = config or InteractionConfig()\n",
    "        self.spline_basis = OptimizedSplineBasis(\n",
    "            n_splines=self.config.n_splines,\n",
    "            spline_degree=self.config.spline_degree,\n",
    "            include_intercept=self.config.include_intercept\n",
    "        )\n",
    "        \n",
    "        self.groups = []\n",
    "        self.group_names = []\n",
    "        self.group_indices = []\n",
    "        self.scaler = None\n",
    "        self.design_matrix_ = None\n",
    "        self.model = None\n",
    "        self.group_norms_ = None\n",
    "        self.coefficients_ = None\n",
    "        self.interaction_importance_ = None\n",
    "    \n",
    "    def _build_design_matrix(self, X_df: pd.DataFrame, \n",
    "                            interactions: Optional[List[Tuple]] = None) -> np.ndarray:\n",
    "        \"\"\"Build design matrix with groups\"\"\"\n",
    "        blocks = []\n",
    "        col_idx = 0\n",
    "        \n",
    "        # Main effects\n",
    "        for col_name in X_df.columns:\n",
    "            x_col = X_df[col_name].values\n",
    "            B = self.spline_basis.build_univariate_basis(x_col)\n",
    "            \n",
    "            blocks.append(B)\n",
    "            self.groups.append(list(range(col_idx, col_idx + B.shape[1])))\n",
    "            self.group_names.append((col_name,))\n",
    "            self.group_indices.append([X_df.columns.get_loc(col_name)])\n",
    "            \n",
    "            col_idx += B.shape[1]\n",
    "            \n",
    "            if self.config.model_verbose:\n",
    "                print(f\"Main effect: {col_name} ({B.shape[1]} basis)\")\n",
    "        \n",
    "        # Interactions\n",
    "        if interactions:\n",
    "            for int_tuple in interactions:\n",
    "                if len(int_tuple) == 2:\n",
    "                    a, b = int_tuple\n",
    "                    col_a, col_b = X_df.columns[a], X_df.columns[b]\n",
    "                    \n",
    "                    B_int = self.spline_basis.build_bivariate_basis(\n",
    "                        X_df[col_a].values, X_df[col_b].values\n",
    "                    )\n",
    "                    \n",
    "                    blocks.append(B_int)\n",
    "                    self.groups.append(list(range(col_idx, col_idx + B_int.shape[1])))\n",
    "                    self.group_names.append((col_a, col_b))\n",
    "                    self.group_indices.append([a, b])\n",
    "                    \n",
    "                    col_idx += B_int.shape[1]\n",
    "                    \n",
    "                    if self.config.model_verbose:\n",
    "                        print(f\"2-way: {col_a} & {col_b} ({B_int.shape[1]} basis)\")\n",
    "        \n",
    "        self.design_matrix_ = np.hstack(blocks)\n",
    "        \n",
    "        if self.config.verbose:\n",
    "            print(f\"\\nDesign matrix: {self.design_matrix_.shape}\")\n",
    "            print(f\"   Total groups: {len(self.groups)}\")\n",
    "        \n",
    "        return self.design_matrix_\n",
    "    \n",
    "    def fit(self, X: pd.DataFrame, y: np.ndarray, \n",
    "            interactions: Optional[List[Tuple]] = None,\n",
    "            method: str = 'elastic_net') -> 'OptimizedAdditiveInteractionSelector':\n",
    "        \"\"\"\n",
    "        Fit interaction model with specified method\n",
    "        \n",
    "        Args:\n",
    "            X: Input features (DataFrame)\n",
    "            y: Target variable\n",
    "            interactions: List of interaction tuples\n",
    "            method: 'group_lasso', 'group_lasso_torch', 'elastic_net', 'lasso', 'ridge'\n",
    "        \"\"\"\n",
    "        # Build design matrix\n",
    "        Xd = self._build_design_matrix(X, interactions)\n",
    "        \n",
    "        # Standardize\n",
    "        self.scaler = StandardScaler()\n",
    "        Xd_scaled = self.scaler.fit_transform(Xd)\n",
    "        \n",
    "        if self.config.verbose:\n",
    "            print(f\"\\nFitting {method} model...\")\n",
    "        \n",
    "        # Fit based on method\n",
    "        if method == 'group_lasso':\n",
    "            self._fit_group_lasso(Xd_scaled, y)\n",
    "        \n",
    "        elif method == 'elastic_net':\n",
    "            self._fit_elastic_net(Xd_scaled, y)\n",
    "        \n",
    "        elif method == 'lasso':\n",
    "            self._fit_lasso(Xd_scaled, y)\n",
    "        \n",
    "        elif method == 'ridge':\n",
    "            self._fit_ridge(Xd_scaled, y)\n",
    "        \n",
    "        else:\n",
    "            raise ValueError(f\"Unknown method: {method}. Use: group_lasso, group_lasso_torch, elastic_net, lasso, ridge\")\n",
    "        \n",
    "        # Compute group importance\n",
    "        self._compute_group_importance()\n",
    "        \n",
    "        if self.config.verbose:\n",
    "            print(f\"Model fitted successfully\")\n",
    "        \n",
    "        return self\n",
    "    \n",
    "    def _fit_group_lasso(self, X: np.ndarray, y: np.ndarray):\n",
    "\n",
    "        if self.config.verbose:\n",
    "            print(\"Using skglm.GroupLasso\")\n",
    "        \n",
    "        # Convert group list to group array\n",
    "        # skglm expects groups as array where groups[i] = group_id for feature i\n",
    "        group_array = []\n",
    "        for _, group_indices in enumerate(self.groups):\n",
    "            #for idx in group_indices:\n",
    "            group_array.append(len(group_indices))\n",
    "        \n",
    "        best_score, self.model = -np.inf, None\n",
    "        kf = KFold(n_splits=self.config.cv_folds, shuffle=True, random_state=self.config.random_state)\n",
    "\n",
    "\n",
    "        \n",
    "        for lam in self.config.group_reg_range:\n",
    "            print('='*20)\n",
    "            print(f\"Lambda: {lam:.3f}\")\n",
    "            scores = []; cnt = 1\n",
    "            for tr, va in kf.split(X):\n",
    "                print(f\"{cnt}/{self.config.cv_folds} Folds\")\n",
    "                model = SkglmGroupLasso(\n",
    "                    alpha=lam,\n",
    "                    groups=group_array,\n",
    "                    weights=np.ones(len(self.groups)),\n",
    "                    fit_intercept=True,\n",
    "                    tol=1e-4,\n",
    "                    max_iter=10000,\n",
    "                    verbose=self.config.model_verbose\n",
    "                )\n",
    "                model.fit(X[tr], y[tr])\n",
    "                scores.append(model.score(X[va], y[va]))\n",
    "                cnt += 1\n",
    "            if np.mean(scores) > best_score:\n",
    "                best_score = np.mean(scores)\n",
    "                self.model = SkglmGroupLasso(\n",
    "                    alpha=lam,\n",
    "                    groups=group_array,\n",
    "                    weights=np.ones(len(self.groups)),\n",
    "                    fit_intercept=True,\n",
    "                    tol=1e-4,\n",
    "                    max_iter=10000,\n",
    "                    verbose=self.config.model_verbose\n",
    "                )\n",
    "            print('='*20)\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "        '''\n",
    "        self.model = SkglmGroupLasso(\n",
    "            alpha=self.config.group_reg,\n",
    "            groups=group_array,\n",
    "            fit_intercept=True,\n",
    "            tol=1e-4,\n",
    "            max_iter=10000,\n",
    "            verbose=self.config.verbose\n",
    "        )\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "        '''\n",
    "    \n",
    "    def _fit_elastic_net(self, X: np.ndarray, y: np.ndarray):\n",
    "        \"\"\"Fit with ElasticNet\"\"\"\n",
    "        from sklearn.linear_model import ElasticNetCV\n",
    "        if self.config.verbose:\n",
    "            print(\"Using ElasticNetCV\")\n",
    "        \n",
    "        self.model = ElasticNetCV(\n",
    "            cv=self.config.cv_folds,\n",
    "            alphas=self.config.lambda_range,\n",
    "            l1_ratio=[0.1, 0.3, 0.5, 0.7, 0.9],\n",
    "            max_iter=10000\n",
    "        )\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "    \n",
    "    def _fit_lasso(self, X: np.ndarray, y: np.ndarray):\n",
    "        \"\"\"Fit with Lasso\"\"\"\n",
    "        from sklearn.linear_model import LassoCV\n",
    "        if self.config.verbose:\n",
    "            print(\"Using LASSO\")\n",
    "        self.model = LassoCV(cv=self.config.cv_folds, alphas=self.config.lambda_range)\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "    \n",
    "    def _fit_ridge(self, X: np.ndarray, y: np.ndarray):\n",
    "        \"\"\"Fit with Ridge\"\"\"\n",
    "        from sklearn.linear_model import RidgeCV\n",
    "        if self.config.verbose:\n",
    "            print(\"Using Ridge\")\n",
    "        self.model = RidgeCV(alphas=self.config.lambda_range, cv=self.config.cv_folds)\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "    \n",
    "    def _compute_group_importance(self):\n",
    "        \"\"\"Compute L2 norm for each group\"\"\"\n",
    "        if hasattr(self.model, 'get_group_norms'):\n",
    "            # For TorchGroupLasso\n",
    "            self.group_norms_ = self.model.model.get_group_norms()\n",
    "        else:\n",
    "            # For sklearn methods\n",
    "            self.group_norms_ = []\n",
    "            for group_indices in self.groups:\n",
    "                group_coef = self.coefficients_[group_indices]\n",
    "                norm = np.linalg.norm(group_coef, ord=2)\n",
    "                self.group_norms_.append(norm)\n",
    "            self.group_norms_ = np.array(self.group_norms_)\n",
    "        \n",
    "        self.interaction_importance_ = pd.DataFrame({\n",
    "            'group': self.group_names,\n",
    "            'feature_indices': self.group_indices,\n",
    "            'importance': self.group_norms_\n",
    "        }).sort_values('importance', ascending=False).reset_index(drop=True)\n",
    "    \n",
    "    def get_interaction_ranking(self) -> pd.DataFrame:\n",
    "        \"\"\"Return ranked interactions by importance\"\"\"\n",
    "        return self.interaction_importance_.copy()\n",
    "    \n",
    "    def get_important_interactions(self, threshold: float = None,\n",
    "                                  top_k: Optional[int] = None) -> List[Tuple]:\n",
    "        \"\"\"Get important interactions\"\"\"\n",
    "        if self.interaction_importance_ is None:\n",
    "            raise ValueError(\"Model not fitted. Call fit() first.\")\n",
    "        \n",
    "        df = self.interaction_importance_.copy()\n",
    "        \n",
    "        if top_k is not None:\n",
    "            df = df.head(top_k)\n",
    "        elif threshold is not None:\n",
    "            df = df[df['importance'] > threshold]\n",
    "        \n",
    "        return [tuple(int(i) for i in row) for row in df['feature_indices']]\n",
    "    \n",
    "    def summary(self, top_n: int = 10):\n",
    "        \"\"\"Print summary\"\"\"\n",
    "        print(\"\\n\" + \"=\"*70)\n",
    "        print(\"INTERACTION IMPORTANCE RANKING\")\n",
    "        print(\"=\"*70)\n",
    "        \n",
    "        df = self.interaction_importance_.head(top_n)\n",
    "        \n",
    "        for idx, row in df.iterrows():\n",
    "            group_name = \" & \".join(row['group']) if isinstance(row['group'], tuple) else str(row['group'])\n",
    "            print(f\"{idx+1:2d}. {group_name:30s} | Importance: {row['importance']:.6f}\")\n",
    "        \n",
    "        print(\"=\"*70)\n",
    "\n",
    "# \n",
    "# Section 4: TorchInteractionDetector (Neural Alternative)\n",
    "# \n",
    "\n",
    "class TorchInteractionDetector(nn.Module):\n",
    "    \"\"\"Neural network for interaction detection with attention\"\"\"\n",
    "    \n",
    "    def __init__(self, input_dim: int, hidden_dims: List[int] = [64, 32],\n",
    "                 n_interactions: int = None, dropout: float = 0.2):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.input_dim = input_dim\n",
    "        self.n_interactions = n_interactions or (input_dim * 2)\n",
    "        \n",
    "        # Trunk\n",
    "        trunk_layers = []\n",
    "        in_dim = input_dim\n",
    "        \n",
    "        for h_dim in hidden_dims:\n",
    "            trunk_layers.extend([\n",
    "                nn.Linear(in_dim, h_dim),\n",
    "                nn.BatchNorm1d(h_dim),\n",
    "                nn.GELU(),\n",
    "                nn.Dropout(dropout)\n",
    "            ])\n",
    "            in_dim = h_dim\n",
    "        \n",
    "        self.trunk = nn.Sequential(*trunk_layers)\n",
    "        \n",
    "        # Prediction head\n",
    "        self.interaction_head = nn.Sequential(\n",
    "            nn.Linear(in_dim, self.n_interactions),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(self.n_interactions, 1)\n",
    "        )\n",
    "        \n",
    "        # Attention mechanism\n",
    "        self.attention = nn.Sequential(\n",
    "            nn.Linear(in_dim, self.n_interactions),\n",
    "            nn.Softmax(dim=1)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        trunk_out = self.trunk(x)\n",
    "        predictions = self.interaction_head(trunk_out)\n",
    "        weights = self.attention(trunk_out)\n",
    "        return predictions, weights\n",
    "\n",
    "\n",
    "class TorchInteractionDetectorTrainer:\n",
    "    \"\"\"Trainer for TorchInteractionDetector\"\"\"\n",
    "    \n",
    "    def __init__(self, model: TorchInteractionDetector,\n",
    "                 learning_rate: float = 1e-3, sparsity_weight: float = 0.01):\n",
    "        self.model = model.to(device)\n",
    "        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
    "            self.optimizer, T_max=1000, eta_min=1e-6\n",
    "        )\n",
    "        self.criterion = nn.MSELoss()\n",
    "        self.sparsity_weight = sparsity_weight\n",
    "        self.history = {'train_loss': [], 'val_loss': []}\n",
    "    \n",
    "    def fit(self, X_train: torch.Tensor, y_train: torch.Tensor,\n",
    "            X_val: torch.Tensor, y_val: torch.Tensor,\n",
    "            n_epochs: int = 1000, batch_size: int = 128):\n",
    "        \"\"\"Train model\"\"\"\n",
    "        train_dataset = TensorDataset(X_train, y_train)\n",
    "        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "        \n",
    "        val_dataset = TensorDataset(X_val, y_val)\n",
    "        val_loader = DataLoader(val_dataset, batch_size=batch_size*2, shuffle=False)\n",
    "        \n",
    "        best_val_loss = float('inf')\n",
    "        patience = 50\n",
    "        patience_counter = 0\n",
    "        \n",
    "        pbar = tqdm(range(n_epochs), desc='Training')\n",
    "        \n",
    "        for epoch in pbar:\n",
    "            # Train\n",
    "            self.model.train()\n",
    "            train_loss = 0.0\n",
    "            \n",
    "            for X_batch, y_batch in train_loader:\n",
    "                X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
    "                self.optimizer.zero_grad()\n",
    "                \n",
    "                y_pred, weights = self.model(X_batch)\n",
    "                mse_loss = self.criterion(y_pred, y_batch)\n",
    "                sparsity_loss = weights.abs().mean()\n",
    "                loss = mse_loss + self.sparsity_weight * sparsity_loss\n",
    "                \n",
    "                loss.backward()\n",
    "                self.optimizer.step()\n",
    "                train_loss += mse_loss.item()\n",
    "            \n",
    "            train_loss /= len(train_loader)\n",
    "            \n",
    "            # Validate\n",
    "            self.model.eval()\n",
    "            val_loss = 0.0\n",
    "            with torch.no_grad():\n",
    "                for X_batch, y_batch in val_loader:\n",
    "                    X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
    "                    y_pred, _ = self.model(X_batch)\n",
    "                    loss = self.criterion(y_pred, y_batch)\n",
    "                    val_loss += loss.item()\n",
    "            \n",
    "            val_loss /= len(val_loader)\n",
    "            \n",
    "            self.scheduler.step()\n",
    "            self.history['train_loss'].append(train_loss)\n",
    "            self.history['val_loss'].append(val_loss)\n",
    "            \n",
    "            pbar.set_postfix({'train': f'{train_loss:.6f}', 'val': f'{val_loss:.6f}'})\n",
    "            \n",
    "            # Early stopping\n",
    "            if val_loss < best_val_loss:\n",
    "                best_val_loss = val_loss\n",
    "                patience_counter = 0\n",
    "            else:\n",
    "                patience_counter += 1\n",
    "                if patience_counter >= patience:\n",
    "                    print(f\"\\nEarly stopping at epoch {epoch+1}\")\n",
    "                    break\n",
    "        \n",
    "        return self.history\n",
    "    \n",
    "    def get_interaction_importance(self, X: torch.Tensor) -> np.ndarray:\n",
    "        \"\"\"Get average attention weights\"\"\"\n",
    "        self.model.eval()\n",
    "        with torch.no_grad():\n",
    "            X = X.to(device)\n",
    "            _, weights = self.model(X)\n",
    "            importance = weights.mean(dim=0).cpu().numpy()\n",
    "        return importance\n",
    "    \n",
    "config = InteractionConfig(n_splines=8, group_reg=0.1)\n",
    "selector = OptimizedAdditiveInteractionSelector(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57ca90f8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "552f599e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"=\"*70)\n",
    "print(\"Optimized Higher-Order Interaction Detection\")\n",
    "print(\"=\"*70 + \"\\n\")\n",
    "\n",
    "# Generate sample data\n",
    "\n",
    "\n",
    "X_train, interaction_pairs = extract_active_features(np.array(X_train), results_['opt_var'])\n",
    "y_train = np.array(y_train)\n",
    "\n",
    "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)\n",
    "\n",
    "X_test, interaction_pairs = extract_active_features(np.array(X_test), results_['opt_var'])\n",
    "y_test = np.array(y_test)\n",
    "\n",
    "# 1. Fit main effects + 2-way interactions\n",
    "config = InteractionConfig(\n",
    "    n_splines=3,\n",
    "    cv_folds=5,\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "selector = OptimizedAdditiveInteractionSelector(config)\n",
    "selector.fit(X_train, y_train, interactions=interaction_pairs, method='group_lasso')\n",
    "selector.summary(top_n=6)\n",
    "\n",
    "important = selector.get_important_interactions(top_k=3)\n",
    "print(f\"\\nTop 3 important interactions: {important}\")\n",
    "\n",
    "# 3. Ranking table\n",
    "print(\"\\nFull Ranking:\")\n",
    "print(selector.get_interaction_ranking())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ca967ab4-106d-40e3-a1cf-b2f28598b24f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wine interaction\n",
    "#[3, 7]    0.128504\n",
    "#[1, 7]    0.070927\n",
    "#[5, 6]    0.059460\n",
    "#[2, 5]    0.045301\n",
    "#[1, 4]    0.029785\n",
    "#[0, 7]    0.022659\n",
    "#[4, 5]\n",
    "\n",
    "# Bike interaction\n",
    "# [3, 6]\t61.145662\n",
    "# [3, 8]\t52.060780\n",
    "# [1, 3]\t29.021210\n",
    "# [3, 10]\t22.047098\n",
    "# [3, 7]\t15.436939\n",
    "# [0, 3]\t7.499452\n",
    "# [2, 3]\t5.741179\n",
    "# [6, 10]\t5.542783\n",
    "# [8, 9]\t5.286326\n",
    "# [6, 9]\t5.107520\n",
    "# [[3, 6],[3, 8],[1, 3],[3, 10],[3, 7],[0, 3],[2, 3],[6, 10],[8, 9],[6, 9]]\n",
    "# CA interaction\n",
    "#[6, 7]    0.333172\n",
    "#[0, 5]    0.212995\n",
    "#[0, 1]    0.082654\n",
    "#[3, 5]    0.028967\n",
    "\n",
    "# FICO interaction\n",
    "#[15, 18]    0.008222\n",
    "#[4, 11]    0.004065\n",
    "#[6, 15]    0.003238\n",
    "#[4, 5]    0.002157\n",
    "#[1, 3]    0.001016\n",
    "#[[15, 18],[4, 11],[6, 15],[4, 5],[1, 3]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f997159c-53eb-44b1-b9bc-3c0fe3d0de0e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a78e5caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "ranking_df = selector.get_interaction_ranking()\n",
    "interaction = [i for i in ranking_df[ranking_df['importance']>0]['feature_indices']]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60453e54",
   "metadata": {},
   "source": [
    "### SDAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "578b7931-8840-48c9-bd5f-18c412d845d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "Train = pd.read_csv('./'+case+'_Train.csv')\n",
    "Valid = pd.read_csv('./'+case+'_Valid.csv')\n",
    "Test = pd.read_csv('./'+case+'_Test.csv')\n",
    "\n",
    "y_train = Train['y']\n",
    "X_train = Train.iloc[:, :-1]\n",
    "y_val = Valid['y']\n",
    "X_val = Valid.iloc[:, :-1]\n",
    "y_test = Test['y']\n",
    "X_test = Test.iloc[:, :-1]\n",
    "\n",
    "X_train = torch.tensor(X_train.to_numpy(), dtype=torch.float32)\n",
    "y_train = torch.tensor(y_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val.to_numpy(), dtype=torch.float32)\n",
    "y_val = torch.tensor(y_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test.to_numpy(), dtype=torch.float32)\n",
    "y_test = torch.tensor(y_test, dtype=torch.float32)\n",
    "\n",
    "scaler = StandardScaler()\n",
    "\n",
    "X_train = scaler.fit_transform(X_train.cpu().numpy())\n",
    "X_val = scaler.transform(X_val.cpu().numpy())\n",
    "X_test = scaler.transform(X_test.cpu().numpy())\n",
    "\n",
    "# 4. 轉回 Tensor\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "583b960c-5f5c-4d53-92d9-93f496fff98d",
   "metadata": {},
   "source": [
    "### NAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bb2b06b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from dataclasses import dataclass, field, asdict\n",
    "from tqdm.auto import tqdm\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "@dataclass\n",
    "class NAMConfig:\n",
    "    input_dim: int\n",
    "    index_list: list\n",
    "    hidden_sizes: list = field(default_factory=lambda: [64, 32])\n",
    "    use_exu: bool = True\n",
    "    dropout: float = 0.1\n",
    "    use_batchnorm: bool = False\n",
    "    batch_size: int = 128\n",
    "    lr: float = 3e-4\n",
    "    l2: float = 1e-4\n",
    "    n_epochs: int = 200\n",
    "    grad_clip: float = 1.0\n",
    "    early_stopping_patience: int = 20\n",
    "    ckpt_dir: str = \"checkpoints\"\n",
    "    tqlm: bool = True  # tqdm顯示進度列\n",
    "\n",
    "class ExU(nn.Module):\n",
    "    def __init__(self, in_features): super().__init__(); self.weight = nn.Parameter(torch.randn(in_features))\n",
    "    def forward(self, x): exp_weight = torch.exp(self.weight.clamp(-10, 10)); return (exp_weight - 1) * torch.where(x > 0, x, torch.exp(x) - 1)\n",
    "\n",
    "class FeatureNN(nn.Module):\n",
    "    def __init__(self, input_dim=1, hidden_sizes=[64,32], use_exu=True, dropout=0.1, use_batchnorm=False):\n",
    "        super().__init__()\n",
    "        self.layers = nn.ModuleList()\n",
    "        last_dim = input_dim\n",
    "        for h in hidden_sizes:\n",
    "            self.layers.append(nn.Linear(last_dim, h))\n",
    "            if use_batchnorm: self.layers.append(nn.BatchNorm1d(h))\n",
    "            self.layers.append(ExU(h) if use_exu else nn.ReLU())\n",
    "            if dropout > 0: self.layers.append(nn.Dropout(dropout))\n",
    "            last_dim = h\n",
    "        self.layers.append(nn.Linear(last_dim, 1))\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Linear): nn.init.kaiming_normal_(m.weight, mode=\"fan_in\")\n",
    "\n",
    "    def forward(self, x):\n",
    "        if x.dim() == 1: x = x.unsqueeze(1)\n",
    "        for layer in self.layers: x = layer(x)\n",
    "        return x\n",
    "\n",
    "class NAM(nn.Module):\n",
    "    def __init__(self, config: NAMConfig):\n",
    "        super().__init__()\n",
    "        self.index_list = config.index_list\n",
    "        self.feature_nns = nn.ModuleList([\n",
    "            FeatureNN(len(indices), config.hidden_sizes, config.use_exu, config.dropout, config.use_batchnorm)\n",
    "            for indices in self.index_list\n",
    "        ])\n",
    "        self.bias = nn.Parameter(torch.zeros(1))\n",
    "        self.config = config\n",
    "\n",
    "    def forward(self, x):\n",
    "        terms = []\n",
    "        for indices, net in zip(self.index_list, self.feature_nns):\n",
    "            col = x[:, indices] if len(indices) > 1 else x[:, indices[0]].unsqueeze(1)\n",
    "            out = net(col)\n",
    "            if torch.isnan(out).any() or torch.isinf(out).any():\n",
    "                out = torch.zeros_like(out)\n",
    "            terms.append(out)\n",
    "        result = self.bias + torch.cat(terms, dim=1).sum(1)\n",
    "        result = torch.where(torch.isnan(result) | torch.isinf(result), torch.zeros_like(result), result)\n",
    "        return result\n",
    "\n",
    "    def save_checkpoint(self, path, optimizer=None, scheduler=None, epoch=None, best_val_loss=None):\n",
    "        state = {\n",
    "            \"model_state_dict\": self.state_dict(),\n",
    "            \"config\": asdict(self.config)\n",
    "        }\n",
    "        if optimizer: state[\"optimizer_state_dict\"] = optimizer.state_dict()\n",
    "        if scheduler: state[\"scheduler_state_dict\"] = scheduler.state_dict()\n",
    "        if epoch is not None: state[\"epoch\"] = epoch\n",
    "        if best_val_loss is not None: state[\"best_val_loss\"] = best_val_loss\n",
    "        torch.save(state, path)\n",
    "        print(f\"Checkpoint saved to {path}\")\n",
    "\n",
    "    @staticmethod\n",
    "    def load_checkpoint(path, device=\"cpu\"):\n",
    "        checkpoint = torch.load(path, weights_only=False, map_location=device)\n",
    "        config = NAMConfig(**checkpoint[\"config\"])\n",
    "        model = NAM(config).to(device)\n",
    "        model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "        print(f\"Checkpoint loaded from {path}\")\n",
    "        return model, config\n",
    "\n",
    "class NAMTrainer:\n",
    "    def __init__(self, config: NAMConfig, model=None):\n",
    "        self.config = config\n",
    "        #self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        self.device = torch.device(\"cpu\")\n",
    "        self.model = model or NAM(config)\n",
    "        self.model.to(self.device)\n",
    "        self.best_val_loss = float(\"inf\")\n",
    "        self.best_epoch = None\n",
    "        self.ckpt_dir = config.ckpt_dir\n",
    "        os.makedirs(self.ckpt_dir, exist_ok=True)\n",
    "        self.history = []\n",
    "\n",
    "    def create_loader(self, X, y, is_train=True):\n",
    "        dataset = TensorDataset(X, y)\n",
    "        return DataLoader(\n",
    "            dataset,\n",
    "            batch_size=self.config.batch_size,\n",
    "            shuffle=is_train\n",
    "        )\n",
    "\n",
    "    def fit(self, X_train, y_train, X_val, y_val):\n",
    "        train_loader = self.create_loader(X_train, y_train, is_train=True)\n",
    "        val_loader = self.create_loader(X_val, y_val, is_train=False)\n",
    "        optimizer = torch.optim.AdamW(self.model.parameters(),\n",
    "                                      lr=self.config.lr, weight_decay=self.config.l2)\n",
    "        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.config.n_epochs)\n",
    "        patience = 0\n",
    "\n",
    "        pbar = tqdm(range(1, self.config.n_epochs+1), disable=not self.config.tqlm)\n",
    "        for epoch in pbar:\n",
    "            self.model.train()\n",
    "            train_losses = []\n",
    "            for xb, yb in train_loader:\n",
    "                xb, yb = xb.to(self.device), yb.to(self.device)\n",
    "                optimizer.zero_grad()\n",
    "                pred = self.model(xb)\n",
    "                loss = F.mse_loss(pred, yb)\n",
    "                if torch.isnan(loss) or torch.isinf(loss):\n",
    "                    print(f\"Loss nan/inf at epoch {epoch}\")\n",
    "                    return\n",
    "                loss.backward()\n",
    "                nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)\n",
    "                optimizer.step()\n",
    "                train_losses.append(loss.item())\n",
    "            scheduler.step()\n",
    "            avg_train_loss = np.mean(train_losses)\n",
    "\n",
    "            # 每 50 個 epoch 才驗證 (eval) 並 early stop\n",
    "            if epoch % 10 == 0:\n",
    "                self.model.eval()\n",
    "                val_losses = []\n",
    "                with torch.no_grad():\n",
    "                    for xb, yb in val_loader:\n",
    "                        val_pred = self.model(xb.to(self.device))\n",
    "                        vloss = F.mse_loss(val_pred, yb.to(self.device))\n",
    "                        val_losses.append(vloss.item())\n",
    "                avg_val_loss = np.mean(val_losses)\n",
    "                self.history.append((avg_train_loss, avg_val_loss))\n",
    "                pbar.set_postfix({\"train_loss\": f\"{avg_train_loss:.5f}\", \"val_loss\": f\"{avg_val_loss:.5f}\"})\n",
    "                # 只在val loss有新低時儲存\n",
    "                if avg_val_loss < self.best_val_loss:\n",
    "                    self.best_val_loss = avg_val_loss\n",
    "                    self.best_epoch = epoch\n",
    "                    self.model.save_checkpoint(\n",
    "                        os.path.join(self.ckpt_dir, \"best_model.pt\"),\n",
    "                        optimizer, scheduler, epoch, avg_val_loss\n",
    "                    )\n",
    "                    patience = 0\n",
    "                else:\n",
    "                    patience += 1\n",
    "                if patience > self.config.early_stopping_patience:\n",
    "                    print(f\"Early stopping at epoch {epoch} (best val loss: {self.best_val_loss:.6f})\")\n",
    "                    break\n",
    "            else:\n",
    "                self.history.append((avg_train_loss, None))\n",
    "                pbar.set_postfix({\"train_loss\": f\"{avg_train_loss:.5f}\", \"val_loss\": \"N/A\"})\n",
    "\n",
    "    def predict(self, X, y=None):\n",
    "        loader = self.create_loader(X, y if y is not None else torch.zeros(len(X)), is_train=False)\n",
    "        self.model.eval()\n",
    "        preds = []\n",
    "        with torch.no_grad():\n",
    "            for xb, _ in loader:\n",
    "                pred = self.model(xb.to(self.device))\n",
    "                preds.append(pred.cpu())\n",
    "        return torch.cat(preds).numpy()\n",
    "\n",
    "    def load_best_model(self, path=None):\n",
    "        path = path or os.path.join(self.ckpt_dir, \"best_model.pt\")\n",
    "        self.model, _ = NAM.load_checkpoint(path, device=self.device)\n",
    "        self.model.to(self.device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e27173b8-60b2-44e6-92a7-458246f62ecb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing RMSE: 0.49322653\n"
     ]
    }
   ],
   "source": [
    "print(\"Testing RMSE:\", np.sqrt(np.mean((test_pred - y_val.numpy()) ** 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d320d32-f068-41cc-83f0-6252134c83c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "index_list = [[i] for i in range(X_train.size()[1])] + unique_elements\n",
    "\n",
    "config = NAMConfig(\n",
    "    input_dim=X_train.size()[1],\n",
    "    index_list=index_list,\n",
    "    hidden_sizes=[128, 64, 32],\n",
    "    use_exu=False,\n",
    "    lr = 1e-2,\n",
    "    dropout=0.1,\n",
    "    use_batchnorm=True,\n",
    "    batch_size=1024,\n",
    "    n_epochs=1000,\n",
    "    tqlm=True\n",
    ")\n",
    "trainer = NAMTrainer(config)\n",
    "trainer.fit(X_train, y_train, X_val, y_val)\n",
    "trainer.load_best_model()\n",
    "test_pred = trainer.predict(X_test)\n",
    "print(\"Testing RMSE:\", np.sqrt(np.mean((test_pred - y_test.numpy()) ** 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7ad86a8-6a8e-4831-9b35-e784a09af9c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "8abecbe0-fc3a-4c50-9f96-553a7505fa3b",
   "metadata": {},
   "source": [
    "## DNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7df9e2f7-4bad-4f49-8f30-7c8c017aabb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DNNBaseline(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dims=[64, 32]):\n",
    "        super(DNNBaseline, self).__init__()\n",
    "        layers = []\n",
    "        \n",
    "        for h_dim in hidden_dims:\n",
    "            layers.extend([\n",
    "                nn.Linear(input_dim, h_dim),\n",
    "                nn.ReLU()\n",
    "            ])\n",
    "            input_dim = h_dim\n",
    "            \n",
    "        # Final output layer\n",
    "        layers.append(nn.Linear(input_dim, 1))\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "    \n",
    "class EarlyStopper:\n",
    "    def __init__(self, patience=1, min_delta=0):\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.counter = 0\n",
    "        self.min_validation_loss = float('inf')\n",
    "\n",
    "    def early_stop(self, validation_loss):\n",
    "        if validation_loss < self.min_validation_loss:\n",
    "            self.min_validation_loss = validation_loss\n",
    "            self.counter = 0\n",
    "        elif validation_loss > (self.min_validation_loss + self.min_delta):\n",
    "            self.counter += 1\n",
    "            if self.counter >= self.patience:\n",
    "                return True\n",
    "        return False\n",
    "    \n",
    "def train(model, X_train, y_train, X_val, y_val, file_path, n_epochs = 10000, batch_size=64, lr=1e-2, pt = 50):\n",
    "    loss_fn = nn.MSELoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr = lr)\n",
    "    scheduler = ReduceLROnPlateau(optimizer, 'min')\n",
    "    early_stopping = EarlyStopper(patience = pt)\n",
    "\n",
    "    # Training loop\n",
    "    n_epochs = n_epochs\n",
    "    batch_size = X_train.size()[0]\n",
    "    batch_start = torch.arange(0, len(X_train), batch_size)\n",
    "\n",
    "    best_mse = float('inf')\n",
    "    best_weights = None\n",
    "\n",
    "    start_time = time.time()\n",
    "    patient = 0\n",
    "    # Training loop\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        \n",
    "        for start in batch_start:\n",
    "            X_batch = X_train[start:start+batch_size]\n",
    "            y_batch = y_train[start:start+batch_size]\n",
    "            \n",
    "            # Forward pass\n",
    "            y_pred = model(X_batch)\n",
    "            loss = loss_fn(y_pred, y_batch)\n",
    "            \n",
    "            # Backward pass\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        \n",
    "        # Evaluate model on test set at the end of each epoch\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            y_pred = model(X_val)\n",
    "            val_loss = loss_fn(y_pred, y_val)\n",
    "            val_loss = float(val_loss)\n",
    "            scheduler.step(val_loss)\n",
    "            \n",
    "            if early_stopping.early_stop(val_loss):\n",
    "                print(f\"Early Stop at Epoch {epoch}\")\n",
    "                break\n",
    "            \n",
    "            if val_loss < best_mse:\n",
    "                best_mse = val_loss\n",
    "                best_weights = copy.deepcopy(model.state_dict())\n",
    "            \n",
    "        '''  \n",
    "        if patient == pt:\n",
    "            print(f\"Early Stop at Epoch {epoch}\")\n",
    "            break\n",
    "        '''\n",
    "        \n",
    "        if (epoch+1) % 1000 == 0:\n",
    "            print(f\"Epoch {epoch+1}, MSE: {val_loss}\")\n",
    "\n",
    "    end_time = time.time()\n",
    "    Time_consumption = end_time - start_time\n",
    "    torch.save(best_weights, file_path)\n",
    "\n",
    "    return Time_consumption\n",
    "\n",
    "def eval_model(model, path, X, y):\n",
    "    model.load_state_dict(torch.load(path, weights_only=True))\n",
    "    model.eval()\n",
    "\n",
    "    loss_fn = nn.MSELoss()\n",
    "    with torch.no_grad():\n",
    "        y_pred = model(X)\n",
    "        mse = np.sqrt(loss_fn(y, y_pred))\n",
    "\n",
    "    return mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d24ecf3b-f436-47c5-a21e-1eba9d7a8899",
   "metadata": {},
   "outputs": [],
   "source": [
    "case = 'wine'\n",
    "path = '/home/users/yhung7/SDAM/Dataset/Real-Data-Application/'\n",
    "Train = pd.read_csv(path+case+'_Train.csv')\n",
    "Valid = pd.read_csv(path+case+'_Valid.csv')\n",
    "Test = pd.read_csv(path+case+'_Test.csv')\n",
    "\n",
    "y_train = Train['y']\n",
    "X_train = Train.iloc[:, :-1]\n",
    "y_val = Valid['y']\n",
    "X_val = Valid.iloc[:, :-1]\n",
    "y_test = Test['y']\n",
    "X_test = Test.iloc[:, :-1]\n",
    "\n",
    "X_train = torch.tensor(X_train.to_numpy(), dtype=torch.float32)\n",
    "y_train = torch.tensor(y_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val.to_numpy(), dtype=torch.float32)\n",
    "y_val = torch.tensor(y_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test.to_numpy(), dtype=torch.float32)\n",
    "y_test = torch.tensor(y_test, dtype=torch.float32)\n",
    "\n",
    "scaler = StandardScaler()\n",
    "\n",
    "X_train = scaler.fit_transform(X_train.cpu().numpy())\n",
    "X_val = scaler.transform(X_val.cpu().numpy())\n",
    "X_test = scaler.transform(X_test.cpu().numpy())\n",
    "\n",
    "# 4. 轉回 Tensor\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "8ed5913f-0a86-47eb-b568-2a8cb1f294be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early Stop at Epoch 517\n"
     ]
    }
   ],
   "source": [
    "DNN = DNNBaseline(input_dim=X_train.size(1), hidden_dims = [128, 64, 32])\n",
    "Runtime_dict = train(DNN, X_train, y_train.view(-1, 1), X_val, y_val.view(-1, 1), 'DNN_'+case+'.pth', n_epochs = 20000, batch_size=1024, lr = 5e-3)\n",
    "MSPE_dict= eval_model(DNN, 'DNN_'+case+'.pth', X_test, y_test.view(-1, 1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "c30c3bce-7174-4ae3-86ca-c29eeebb8a26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "58.84652137756348\n",
      "tensor(0.7002)\n"
     ]
    }
   ],
   "source": [
    "print(Runtime_dict)\n",
    "print(MSPE_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "c96ea320-9e7a-4394-bd2a-0a507e22fbe9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4898\n",
      "11\n"
     ]
    }
   ],
   "source": [
    "print(X_train.size()[0]+X_val.size()[0]+X_test.size()[0])\n",
    "print(X_train.size()[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d02c6c9a-f922-4be6-88cb-0be2045964f4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
