{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "39f55327",
   "metadata": {},
   "source": [
    "## Section 1: Dependency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6f40f833",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n",
      "GPU: Tesla V100-SXM2-32GB\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/users/yhung7/.local/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau\n",
    "from torch.cuda.amp import GradScaler, autocast\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import SplineTransformer, StandardScaler\n",
    "from sklearn.model_selection import train_test_split, KFold\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "from itertools import combinations\n",
    "from tqdm.auto import tqdm\n",
    "from pathlib import Path\n",
    "import time\n",
    "import copy\n",
    "from dataclasses import dataclass\n",
    "from typing import List, Optional, Dict, Tuple\n",
    "import warnings\n",
    "import pandas as pd\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# Device configuration\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")\n",
    "if torch.cuda.is_available():\n",
    "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea8940f2",
   "metadata": {},
   "source": [
    "## Section 2: Splines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "97041836",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SplineProcessor:\n",
    "    \n",
    "    def __init__(self, n_knots=10, boundary_knots=3, degree=3, custom=False):\n",
    "        self.n_knots = n_knots\n",
    "        self.boundary_knots = boundary_knots\n",
    "        self.degree = degree\n",
    "        self.custom = custom\n",
    "    \n",
    "    @staticmethod\n",
    "    def customized_knots(X: np.ndarray, boundary_knots: int, n_knots: int, \n",
    "                        degree: int, unif: bool = False) -> np.ndarray:\n",
    "        lwbnd, upbnd, Xstd = X.min(), X.max(), X.std()\n",
    "        \n",
    "        if unif:\n",
    "            return np.linspace(lwbnd, upbnd, n_knots).reshape(-1, 1)\n",
    "        \n",
    "        lw_knots = np.linspace(lwbnd - 0.1*Xstd, lwbnd + 0.1*Xstd, boundary_knots)\n",
    "        mid_knots = np.linspace(lwbnd + 0.1*Xstd, upbnd - 0.1*Xstd, \n",
    "                               n_knots - boundary_knots)\n",
    "        up_knots = np.linspace(upbnd - 0.1*Xstd, upbnd + 0.1*Xstd, boundary_knots)\n",
    "        knots = np.unique(np.concatenate([lw_knots, mid_knots, up_knots]))\n",
    "        return knots.reshape(-1, 1)\n",
    "    \n",
    "    def transform(self, X: np.ndarray, extrapolation: str = 'constant') -> np.ndarray:\n",
    "        \"\"\"Œspline transformation\"\"\"\n",
    "        X = np.asarray(X).reshape(-1, 1)\n",
    "        \n",
    "        if self.custom:\n",
    "            knots = self.customized_knots(X, self.boundary_knots, \n",
    "                                         self.n_knots, self.degree)\n",
    "            spline = SplineTransformer(knots=knots, degree=self.degree, \n",
    "                                      extrapolation=extrapolation)\n",
    "        else:\n",
    "            spline = SplineTransformer(n_knots=self.n_knots, degree=self.degree,\n",
    "                                      extrapolation=extrapolation)\n",
    "        \n",
    "        return spline.fit_transform(X)\n",
    "    \n",
    "    @staticmethod\n",
    "    def diff_penalty(num_coefs: int, order: int = 2) -> np.ndarray:\n",
    "        return np.diff(np.eye(num_coefs), n=order, axis=0)\n",
    "    \n",
    "    def smoothing_matrix(self, x: np.ndarray, lam: float = 0.1) -> np.ndarray:\n",
    "        B = self.transform(x)\n",
    "        D = self.diff_penalty(B.shape[1], order=2)\n",
    "        \n",
    "        BTB = B.T @ B\n",
    "        DTD = D.T @ D\n",
    "        S = B @ np.linalg.solve(BTB + lam * DTD, B.T)\n",
    "        return S\n",
    "    \n",
    "def spline_transform(X, n_knots=None, boundary_knots=None, degree=3, \n",
    "                    custom=False, extrapolation='constant'):\n",
    "    processor = SplineProcessor(n_knots, boundary_knots, degree, custom)\n",
    "    return processor.transform(X, extrapolation)\n",
    "\n",
    "def PS_smoothing_matrix(x, lam, n_knots, degree, boundary_knots, custom):\n",
    "    processor = SplineProcessor(n_knots, boundary_knots, degree, custom)\n",
    "    return processor.smoothing_matrix(x, lam)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fb37c66",
   "metadata": {},
   "source": [
    "## SAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fe9bc2f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def SAM(X, y, lam = 0, alpha=0.25, max_iter=10, tol=1e-6, ftol = 1e-3, n_knots=10, boundary_knots = 3, degree=3, custom = False):\n",
    "    n_samples, n_features = X.size()\n",
    "    whole_feature = set(list(range(n_features)))\n",
    "    feature_space = (list(range(n_features)))\n",
    "    flag = [True] * n_features\n",
    "    f = torch.zeros((n_samples, n_features))\n",
    "    R = torch.clone(y)\n",
    "    \n",
    "    for _ in range(max_iter):\n",
    "        f_old = torch.clone(f)\n",
    "        flag = [True] * len(feature_space)\n",
    "        for j in range(len(feature_space)):\n",
    "            u_space_idx = [feature_space[j]]\n",
    "            res_space_idx = list(whole_feature-set(u_space_idx))\n",
    "            Res = R - f[:, res_space_idx].sum(axis=1)\n",
    "            Res = torch.FloatTensor(Res)\n",
    "            \n",
    "            PS_matrix = PS_smoothing_matrix(X[:, j], lam = lam, n_knots = n_knots, degree = degree, boundary_knots = boundary_knots, custom = custom)\n",
    "            PS_matrix = torch.FloatTensor(PS_matrix)\n",
    "\n",
    "            P_j = PS_matrix @ Res\n",
    "            s_j = torch.sqrt(torch.mean(P_j**2))\n",
    "            if s_j > alpha and flag[j]:\n",
    "                f[:, feature_space[j]] =  (1 - alpha / s_j) * P_j\n",
    "            else:\n",
    "                flag[j] = False\n",
    "                f[:, feature_space[j]] = 0\n",
    "            \n",
    "            del PS_matrix\n",
    "            \n",
    "        tfs = []\n",
    "        for b in range(len(flag)):\n",
    "            if flag[b]:\n",
    "                tfs.append(feature_space[b])\n",
    "        feature_space = tfs\n",
    "\n",
    "        if (torch.sum(torch.square(f - f_old)) < tol):\n",
    "            print(f\"Alpha: {alpha:.2f} | Convergence.\")\n",
    "\n",
    "            return f\n",
    "\n",
    "    print(f\"Alpha: {alpha:.2f} | Need more iteration.\")\n",
    "\n",
    "    return f\n",
    "\n",
    "def train_SAM(X, y, alpha_list, max_iter, nk, nb, custom):\n",
    "\n",
    "    Max_L1 = torch.zeros((len(alpha_list)))\n",
    "    component_list = {}\n",
    "    result = {}\n",
    "\n",
    "    for i in range(len(alpha_list)):\n",
    "        f = SAM(X, y, lam = 0.1, alpha = alpha_list[i], max_iter = max_iter, tol=1e-6, ftol = 1e-3, n_knots=nk, boundary_knots = nb, degree=3, custom = custom)\n",
    "        # Identify the non-active features among iteration\n",
    "\n",
    "        nonact_idx = torch.where(torch.sum(torch.square(f), axis = 0) == 0)[0]\n",
    "        Max_L1[i] = torch.sum(torch.norm(f, dim = 0))\n",
    "\n",
    "        component_list[i] = f\n",
    "        component_list[i][:, nonact_idx] = 0\n",
    "\n",
    "    GCV_list, loc, active_dict = estimate_gcv(alpha_list, component_list, X, y)\n",
    "    Max_L1 = torch.max(Max_L1)\n",
    "    result['component'] = component_list\n",
    "    result['GCV'] = GCV_list\n",
    "    result['opt_loc'] = loc\n",
    "    result['opt_var'] = active_dict\n",
    "    result['Max_L1'] = Max_L1\n",
    "    result['alpha'] = alpha_list\n",
    "    \n",
    "    return result\n",
    "\n",
    "## Evaluation best feature \n",
    "\n",
    "def estimate_gcv(alpha, comp, X, y):\n",
    "\n",
    "    def estimate_sigma2(y, y_pred, edf_total):\n",
    "        n = len(y)\n",
    "        rss = torch.sum(torch.square(y - y_pred))\n",
    "        if edf_total > n:\n",
    "            return rss\n",
    "        else:\n",
    "            return rss / (n - edf_total)\n",
    "\n",
    "    def compute_cp(y, y_pred, edf_total, sigma2_hat):\n",
    "        n = len(y)\n",
    "        rss = torch.sum(torch.square(y - y_pred))\n",
    "        cp = (rss / n) + 2 * (sigma2_hat / n) * edf_total\n",
    "        return cp\n",
    "\n",
    "    CP = torch.zeros_like(alpha)\n",
    "    n_samples = y.size()[0]\n",
    "    Critical_point = None\n",
    "    \n",
    "    for i in range(len(alpha)-1, -1, -1):\n",
    "        pred_y = comp[i].sum(axis = 1)\n",
    "        \n",
    "        active_set = torch.where(torch.norm(comp[i], p = 1, dim=0) != 0)[0].tolist()\n",
    "        effective_df = 0\n",
    "        for idx in range(len(active_set)):\n",
    "            PS_matrix = PS_smoothing_matrix(X[:, active_set[idx]], lam = 0.1, n_knots = 10, degree = 3, boundary_knots = 3, custom = False)    \n",
    "            #effective_df += (np.trace(PS_matrix)/n_samples)\n",
    "            effective_df += np.trace(PS_matrix)\n",
    "\n",
    "            del PS_matrix\n",
    "\n",
    "        rss = torch.sum(torch.square(y - pred_y))\n",
    "        # Condition for over-fitting\n",
    "        if effective_df > n_samples:\n",
    "            sigma2_hat = rss\n",
    "            if Critical_point == None:\n",
    "                Critical_point = i\n",
    "        else:\n",
    "            sigma2_hat = estimate_sigma2(y, pred_y, effective_df)\n",
    "\n",
    "        cp_value = compute_cp(y, pred_y, effective_df, sigma2_hat)\n",
    "        \n",
    "        CP[i] = cp_value\n",
    "\n",
    "    # Simplify the plot\n",
    "    CP[:Critical_point] = torch.max(CP[Critical_point:])   \n",
    "    minCP = torch.where(CP == torch.min(CP))[0]\n",
    "    optimalloc_ = minCP.item() if len(minCP) == 1 else minCP[0].item()\n",
    "    optimalset_ = torch.where(torch.norm(comp[optimalloc_], p = 1, dim=0) != 0)[0].tolist()\n",
    "    \n",
    "    return CP, optimalloc_, optimalset_\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "481bccb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def report(alpha, result, truey):\n",
    "\n",
    "    summary = {}\n",
    "    act_idx_list = []\n",
    "    rmse_list = torch.zeros((len(alpha)))\n",
    "    for i in range(len(alpha)):\n",
    "        active_idx = torch.where( torch.norm(result['component'][i], dim = 0) != 0)[0]\n",
    "        act_idx_list.append(torch.where( torch.norm(result['component'][i], dim = 0) != 0)[0])\n",
    "        rmse_list[i] = np.sqrt(mean_squared_error(result['component'][i].sum(dim = 1), truey))\n",
    "\n",
    "        print(f\"Alpha: {alpha[i]:3f} | # Active index: {active_idx.size()[0]} | RMSE: {rmse_list[i]:3f} |\")\n",
    "\n",
    "    summary['alpha'] = alpha\n",
    "    summary['act_idx'] = act_idx_list\n",
    "    summary['rmse'] = rmse_list\n",
    "\n",
    "    return summary\n",
    "\n",
    "def extract_active_features(X: torch.tensor, active_idx: list[int]) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Extract active features from X based on active indices.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray\n",
    "        Data matrix of shape (n_samples, n_features)\n",
    "    active_idx : list[int]\n",
    "        Indices of active features\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        DataFrame with columns named x1, x2, ..., for active features\n",
    "    \"\"\"\n",
    "    data = pd.DataFrame({f\"x{idx+1}\": X[:, idx] for idx in active_idx})\n",
    "    interaction = list(combinations(active_idx, 2))\n",
    "    return data, interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb9c9e5b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f8af6c87",
   "metadata": {},
   "source": [
    "## Loading Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "166678af",
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import join as pjoin, exists as pexists\n",
    "\n",
    "import os\n",
    "import glob\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "class RealDatasetLoader:\n",
    "    \"\"\"Utility loader for several real datasets used in experiments.\n",
    "\n",
    "    Methods return a dict containing at minimum:\n",
    "      - X_train: pandas.DataFrame\n",
    "      - y_train: numpy.ndarray\n",
    "      - X_test: pandas.DataFrame\n",
    "      - y_test: numpy.ndarray\n",
    "      - feature_names: list\n",
    "      - problem: 'regression' (for these datasets)\n",
    "\n",
    "    All methods accept a `normalize` flag which, if True, will z-score\n",
    "    the features using the training set statistics. For compatibility with\n",
    "    existing notebooks this loader returns DataFrames for X (so callers\n",
    "    can later call `.to_numpy()` before converting to torch tensors).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, base_path=None):\n",
    "        if base_path is None:\n",
    "            # default to repository Real-dataset folder (one level up from src)\n",
    "            base_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'Real-dataset'))\n",
    "        self.base_path = base_path\n",
    "\n",
    "    def _normalize(self, X_train, X_test):\n",
    "        # X_train/test are pandas DataFrames\n",
    "        x_max = X_train.max(axis=0)\n",
    "        x_min = X_train.min(axis=0)\n",
    "        X_train_n = (X_train - x_min) / (x_max - x_min)\n",
    "        X_test_n = (X_test - x_min) /  (x_max - x_min)\n",
    "        return X_train_n, X_test_n\n",
    "\n",
    "    def _standardize(self, X_train, X_test):\n",
    "        # X_train/test are pandas DataFrames\n",
    "        mu = X_train.mean(axis=0)\n",
    "        std = X_train.std(axis=0).replace(0, 1.0)\n",
    "        X_train_n = (X_train - mu) / std\n",
    "        X_test_n = (X_test - mu) / std\n",
    "        return X_train_n, X_test_n\n",
    "    \n",
    "    def load_wine(self, fold=0, label_col='quality', infile=None, normalize=False, standardize = False):\n",
    "        \"\"\"Load winequality-white dataset and return train/test split by fold files.\n",
    "\n",
    "        Expects files under `<base_path>/wine/`:\n",
    "          - `winequality-white.csv` (semicolon-separated)\n",
    "          - `train{fold}.txt` and `test{fold}.txt` with row indices\n",
    "        \"\"\"\n",
    "        folder = os.path.join(self.base_path, 'wine')\n",
    "        if infile is None:\n",
    "            infile = os.path.join(folder, 'winequality-white.csv')\n",
    "        if not os.path.exists(infile):\n",
    "            raise FileNotFoundError(f\"Wine file not found at {infile}\")\n",
    "\n",
    "        # winequality files use ';' as separator\n",
    "        df = pd.read_csv(infile, sep=';')\n",
    "        if label_col not in df.columns:\n",
    "            raise KeyError(f\"Label column '{label_col}' not found in {infile}\")\n",
    "\n",
    "        train_idx_file = os.path.join(folder, f'train{fold}.txt')\n",
    "        test_idx_file = os.path.join(folder, f'test{fold}.txt')\n",
    "        if not os.path.exists(train_idx_file) or not os.path.exists(test_idx_file):\n",
    "            raise FileNotFoundError('Train/test split files not found under %s' % folder)\n",
    "\n",
    "        train_idx = pd.read_csv(train_idx_file, header=None)[0].values\n",
    "        test_idx = pd.read_csv(test_idx_file, header=None)[0].values\n",
    "\n",
    "        feature_cols = [c for c in df.columns if c != label_col]\n",
    "        X = df[feature_cols]\n",
    "        y = df[label_col].values.astype(np.float32)\n",
    "\n",
    "        X_train = X.iloc[train_idx].reset_index(drop=True)\n",
    "        X_test = X.iloc[test_idx].reset_index(drop=True)\n",
    "        y_train = y[train_idx]\n",
    "        y_test = y[test_idx]\n",
    "\n",
    "        if normalize:\n",
    "            X_train, X_test = self._normalize(X_train, X_test)\n",
    "        elif standardize:\n",
    "            X_train, X_test = self._standardize(X_train, X_test)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        return dict(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,\n",
    "                    feature_names=list(feature_cols), problem='regression')\n",
    "\n",
    "    def load_bikeshare(self, fold=0, hour=True, infile=None, normalize=False, standardize = False):\n",
    "        \"\"\"Load bikeshare dataset. If `hour` True uses `hour.csv`, otherwise `day.csv`.\n",
    "\n",
    "        Expects split indices under `<base_path>/bikeshare/train{fold}.txt` etc.\n",
    "        \"\"\"\n",
    "        folder = os.path.join(self.base_path, 'bikeshare')\n",
    "        fname = 'hour.csv' if hour else 'day.csv'\n",
    "        if infile is None:\n",
    "            infile = os.path.join(folder, fname)\n",
    "        if not os.path.exists(infile):\n",
    "            raise FileNotFoundError(f\"Bikeshare file not found at {infile}\")\n",
    "\n",
    "        df = pd.read_csv(infile).set_index('instant')\n",
    "        train_cols = ['season', 'yr', 'mnth', 'hr', 'holiday', 'weekday',\n",
    "                      'workingday', 'weathersit', 'temp', 'atemp', 'hum', 'windspeed']\n",
    "        label = 'cnt'\n",
    "        for c in train_cols + [label]:\n",
    "            if c not in df.columns:\n",
    "                raise KeyError(f\"Expected column '{c}' not found in bikeshare data\")\n",
    "\n",
    "        X = df[train_cols]\n",
    "        y = df[label].values.astype(np.float32)\n",
    "\n",
    "        train_idx_file = os.path.join(folder, f'train{fold}.txt')\n",
    "        test_idx_file = os.path.join(folder, f'test{fold}.txt')\n",
    "        if not os.path.exists(train_idx_file) or not os.path.exists(test_idx_file):\n",
    "            raise FileNotFoundError('Train/test split files not found under %s' % folder)\n",
    "\n",
    "        train_idx = pd.read_csv(train_idx_file, header=None)[0].values\n",
    "        test_idx = pd.read_csv(test_idx_file, header=None)[0].values\n",
    "\n",
    "        X_train = X.iloc[train_idx].reset_index(drop=True)\n",
    "        X_test = X.iloc[test_idx].reset_index(drop=True)\n",
    "        y_train = y[train_idx]\n",
    "        y_test = y[test_idx]\n",
    "\n",
    "        if normalize:\n",
    "            X_train, X_test = self._normalize(X_train, X_test)\n",
    "        elif standardize:\n",
    "            X_train, X_test = self._standardize(X_train, X_test)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        return dict(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test,\n",
    "                    feature_names=train_cols, problem='regression')\n",
    "\n",
    "    def load_california_housing(self, test_size=0.2, random_state=0, normalize=False, standardize = False):\n",
    "        \"\"\"Fetch California housing from sklearn and split into train/test.\"\"\"\n",
    "        try:\n",
    "            from sklearn.datasets import fetch_california_housing\n",
    "        except Exception as e:\n",
    "            raise RuntimeError('scikit-learn is required to load California housing') from e\n",
    "\n",
    "        housing = fetch_california_housing()\n",
    "        X = pd.DataFrame(housing.data, columns=housing.feature_names)\n",
    "        y = housing.target.astype(np.float32)\n",
    "\n",
    "        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size,\n",
    "                                                            random_state=random_state)\n",
    "        if normalize:\n",
    "            X_train, X_test = self._normalize(X_train, X_test)\n",
    "        elif standardize:\n",
    "            X_train, X_test = self._standardize(X_train, X_test)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        return dict(X_train=X_train.reset_index(drop=True), y_train=y_train,\n",
    "                    X_test=X_test.reset_index(drop=True), y_test=y_test,\n",
    "                    feature_names=housing.feature_names, problem='regression')\n",
    "\n",
    "    def load_fico(self, file_path=None, label_col=None, test_size=0.2, random_state=0, normalize=False, standardize = False):\n",
    "        \"\"\"Load a FICO dataset CSV/Excel. If `file_path` is None we search the base_path\n",
    "        for files containing 'fico' in their filename. The caller must provide a `label_col`\n",
    "        if the dataset doesn't use a standard column name.\n",
    "        \"\"\"\n",
    "\n",
    "        file_path = os.path.join(self.base_path, 'fico.csv')\n",
    "        # try CSV then Excel\n",
    "        df = pd.read_csv(file_path, index_col=0)\n",
    "    \n",
    "\n",
    "        # infer label column if missing\n",
    "        if label_col is None:\n",
    "            if 'target' in df.columns:\n",
    "                label_col = 'target'\n",
    "            else:\n",
    "                # fallback to last column\n",
    "                label_col = df.columns[-1]\n",
    "\n",
    "        if label_col not in df.columns:\n",
    "            raise KeyError(f\"Label column '{label_col}' not found in {file_path}\")\n",
    "\n",
    "        feature_cols = [c for c in df.columns if c != label_col]\n",
    "        X = df[feature_cols]\n",
    "        y = df[label_col].values.astype(np.float32)\n",
    "\n",
    "        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size,\n",
    "                                                            random_state=random_state)\n",
    "        if normalize:\n",
    "            X_train, X_test = self._normalize(X_train, X_test)\n",
    "        elif standardize:\n",
    "            X_train, X_test = self._standardize(X_train, X_test)\n",
    "        else:\n",
    "            pass\n",
    "\n",
    "        return dict(X_train=X_train.reset_index(drop=True), y_train=y_train,\n",
    "                    X_test=X_test.reset_index(drop=True), y_test=y_test,\n",
    "                    feature_names=feature_cols, problem='regression')\n",
    "    \n",
    "path = '/home/users/yhung7/SDAM/Additional_exp'\n",
    "Dataset = RealDatasetLoader(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c09f4592",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_train size: torch.Size([7896, 25]), y_train size: torch.Size([7896])\n",
      "X_test size: torch.Size([1975, 25]), y_test size: torch.Size([1975])\n"
     ]
    }
   ],
   "source": [
    "case = 'fico'\n",
    "\n",
    "if case == 'wine':\n",
    "    dataset_nrm = Dataset.load_wine(normalize=True)\n",
    "    dataset_std = Dataset.load_wine(standardize=True)\n",
    "elif case == 'bike':\n",
    "    dataset_nrm = Dataset.load_wine(normalize=True)\n",
    "    dataset_std = Dataset.load_wine(standardize=True)\n",
    "elif case == 'ca':\n",
    "    dataset_nrm = Dataset.load_california_housing(normalize=True)\n",
    "    dataset_std = Dataset.load_wine(standardize=True)\n",
    "elif case == 'fico':\n",
    "    dataset_nrm = Dataset.load_fico(normalize=True)\n",
    "    dataset_std = Dataset.load_fico(standardize=True)\n",
    "else:\n",
    "    pass\n",
    "\n",
    "X_train = torch.from_numpy(dataset_nrm['X_train'].to_numpy())\n",
    "X_test = torch.from_numpy(dataset_nrm['X_test'].to_numpy())\n",
    "\n",
    "## Normalization\n",
    "y_train = torch.tensor(dataset_nrm['y_train']).to(torch.float32)\n",
    "y_test = torch.tensor(dataset_nrm['y_test']).to(torch.float32)\n",
    "\n",
    "\n",
    "print(f\"X_train size: {X_train.shape}, y_train size: {y_train.shape}\")\n",
    "print(f\"X_test size: {X_test.shape}, y_test size: {y_test.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdac3002-7a9c-467d-89c3-391d6d57a081",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dd2d8a2a-1182-469d-b197-10771fa86d8f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device:  cuda\n"
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "\n",
    "dataset = torch.load('/home/users/yhung7/SDAM/data/only_main_data.pt', weights_only = True)\n",
    "    \n",
    "length, sample, feature = dataset['X_train'].size()\n",
    "mse_list = np.zeros(length)\n",
    "random_state = 0\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "print('device: ', device)\n",
    "\n",
    "length = dataset['X_train'].size()[0]\n",
    "length = 1\n",
    "start_time = time.time()\n",
    "for i in range(length):\n",
    "    \n",
    "    X_train = dataset['X_train'][i]\n",
    "    y_train = dataset['y_train'][i]\n",
    "    X_test = dataset['X_test'][i]\n",
    "    y_test = dataset['y_test'][i]\n",
    "    X_val = dataset['X_valid'][i]\n",
    "    y_val = dataset['y_valid'][i]\n",
    "    scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "\n",
    "    X_train = scaler.fit_transform(X_train.cpu().numpy())\n",
    "    \n",
    "    # 4. 轉回 PyTorch Tensor\n",
    "    X_train = torch.tensor(X_train, dtype=torch.float32)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aceefce-d40c-4578-b957-a8e4a7a06fad",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eb3e6fe-b307-405b-b2a6-99016fb5d236",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1541a7bb",
   "metadata": {},
   "source": [
    "## Main function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33278a2e-5132-42d3-898c-532694ce7397",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MinMaxScaler, StandardScaler\n",
    "\n",
    "case = 'bike'\n",
    "path = '../../Dataset/Real-Data-Application/'\n",
    "\n",
    "Train = pd.read_csv(path+case+'_Train.csv')\n",
    "Valid = pd.read_csv(path+case+'_Valid.csv')\n",
    "Test = pd.read_csv(path+case+'_Test.csv')\n",
    "\n",
    "y_train = Train['y']\n",
    "X_train = Train.iloc[:, :-1]\n",
    "y_val = Valid['y']\n",
    "X_val = Valid.iloc[:, :-1]\n",
    "y_test = Test['y']\n",
    "X_test = Test.iloc[:, :-1]\n",
    "\n",
    "X_train = torch.tensor(X_train.to_numpy(), dtype=torch.float32)\n",
    "y_train = torch.tensor(y_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val.to_numpy(), dtype=torch.float32)\n",
    "y_val = torch.tensor(y_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test.to_numpy(), dtype=torch.float32)\n",
    "y_test = torch.tensor(y_test, dtype=torch.float32)\n",
    "\n",
    "scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "\n",
    "X_train = scaler.fit_transform(X_train.cpu().numpy())\n",
    "\n",
    "# 4. PyTorch Tensor\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b50025f-75a0-460d-a82c-0874ab2fcb1b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ba6a1d4f",
   "metadata": {},
   "source": [
    "### SCREENING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "13580503",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Alpha: 0.00 | Need more iteration.\n",
      "Alpha: 0.00 | Need more iteration.\n",
      "Alpha: 0.01 | Need more iteration.\n",
      "Alpha: 0.05 | Need more iteration.\n",
      "Alpha: 0.50 | Need more iteration.\n",
      "Alpha: 3.50 | Need more iteration.\n",
      "Alpha: 4.50 | Need more iteration.\n",
      "Alpha: 100.00 | Need more iteration.\n",
      "Alpha: 200.00 | Convergence.\n",
      "Alpha: 1000.00 | Convergence.\n"
     ]
    }
   ],
   "source": [
    "if case == 'wine':\n",
    "    alpha_list = torch.tensor(list(np.arange(0.1, 10.1, 1.5))+ [25, 50, 100])\n",
    "elif case == 'ca':\n",
    "    alpha_list = torch.tensor(list(np.arange(0.01, 2.5, 0.3))+ [5])\n",
    "elif case == 'bike':\n",
    "    alpha_list = torch.tensor([0.001, 0.005, 0.01, 0.05, 0.5, 3.5, 4.5, 100, 150, 200])\n",
    "elif case == 'fico':\n",
    "    alpha_list = torch.tensor(list(np.arange(0.005, 0.056, 0.01))+ list(np.arange(0.055, 0.51, 0.1)))\n",
    "else:\n",
    "    pass\n",
    "results_ = train_SAM(X_train[:800,:], y_train[:800], alpha_list, max_iter = 50, nk = 8, nb = 3, custom = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "d9630097",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Alpha: 0.001000 | # Active index: 12 | RMSE: 101.234840 |\n",
      "Alpha: 0.005000 | # Active index: 9 | RMSE: 103.758987 |\n",
      "Alpha: 0.010000 | # Active index: 9 | RMSE: 103.758667 |\n",
      "Alpha: 0.050000 | # Active index: 6 | RMSE: 111.208359 |\n",
      "Alpha: 0.500000 | # Active index: 6 | RMSE: 111.227547 |\n",
      "Alpha: 3.500000 | # Active index: 4 | RMSE: 112.146469 |\n",
      "Alpha: 4.500000 | # Active index: 4 | RMSE: 112.266945 |\n",
      "Alpha: 100.000000 | # Active index: 3 | RMSE: 192.103516 |\n",
      "Alpha: 200.000000 | # Active index: 0 | RMSE: 255.871506 |\n",
      "Alpha: 1000.000000 | # Active index: 0 | RMSE: 255.871506 |\n"
     ]
    }
   ],
   "source": [
    "summary = report(alpha_list, results_, y_train[:800])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "417b3add",
   "metadata": {},
   "source": [
    "## Higher-order detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "339c918f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau\n",
    "from torch.cuda.amp import GradScaler, autocast\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import SplineTransformer, StandardScaler\n",
    "from sklearn.model_selection import KFold, cross_validate\n",
    "from sklearn.linear_model import LassoCV, Lasso, Ridge\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "from itertools import combinations\n",
    "from typing import List, Tuple, Dict, Optional\n",
    "from dataclasses import dataclass\n",
    "from tqdm.auto import tqdm\n",
    "from pathlib import Path\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# ============================================================\n",
    "# Section 1: Optimized Spline Basis Construction\n",
    "# ============================================================\n",
    "\n",
    "class OptimizedSplineBasis:\n",
    "    \"\"\"\n",
    "    Optimized spline basis construction with caching and vectorization.\n",
    "    Significantly faster than original implementation.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, n_splines: int = 10, spline_degree: int = 3, \n",
    "                 include_intercept: bool = False, cache_size: int = 100):\n",
    "        self.n_splines = n_splines\n",
    "        self.spline_degree = spline_degree\n",
    "        self.include_intercept = include_intercept\n",
    "        \n",
    "        # Cache for spline transformers to avoid repeated fitting\n",
    "        self.cache = {}\n",
    "        self.cache_size = cache_size\n",
    "    \n",
    "    def _get_cache_key(self, shape_str: str) -> str:\n",
    "        \"\"\"Generate cache key for basis function.\"\"\"\n",
    "        return f\"spline_{self.n_splines}_{self.spline_degree}_{shape_str}\"\n",
    "    \n",
    "    def build_univariate_basis(self, x: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Build spline basis for one variable with caching.\n",
    "        \n",
    "        Args:\n",
    "            x: 1D array of shape (n_samples,)\n",
    "            \n",
    "        Returns:\n",
    "            Spline basis matrix of shape (n_samples, n_basis_functions)\n",
    "        \"\"\"\n",
    "        x = np.asarray(x).reshape(-1, 1)\n",
    "        \n",
    "        # Try cache first\n",
    "        cache_key = self._get_cache_key(f\"{x.shape[0]}_{x.min():.2f}_{x.max():.2f}\")\n",
    "        \n",
    "        sp = SplineTransformer(\n",
    "            degree=self.spline_degree,\n",
    "            n_knots=self.n_splines,\n",
    "            include_bias=self.include_intercept\n",
    "        )\n",
    "        \n",
    "        basis = sp.fit_transform(x)\n",
    "        self.cache[cache_key] = basis\n",
    "        \n",
    "        return basis\n",
    "    \n",
    "    def build_bivariate_basis(self, x1: np.ndarray, x2: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Build tensor product spline basis for interactions (optimized).\n",
    "        \n",
    "        Uses efficient tensor product instead of einsum for large arrays.\n",
    "        \n",
    "        Args:\n",
    "            x1, x2: 1D arrays of shape (n_samples,)\n",
    "            \n",
    "        Returns:\n",
    "            Interaction basis matrix of shape (n_samples, n_basis1 * n_basis2)\n",
    "        \"\"\"\n",
    "        B1 = self.build_univariate_basis(x1)\n",
    "        B2 = self.build_univariate_basis(x2)\n",
    "        \n",
    "        # Efficient tensor product: (n, p1) x (n, p2) -> (n, p1*p2)\n",
    "        n_samples = B1.shape[0]\n",
    "        n_basis1, n_basis2 = B1.shape[1], B2.shape[1]\n",
    "        \n",
    "        # Method 1: Vectorized outer product (memory efficient)\n",
    "        result = np.zeros((n_samples, n_basis1 * n_basis2))\n",
    "        for i in range(n_samples):\n",
    "            result[i] = np.outer(B1[i], B2[i]).ravel()\n",
    "        \n",
    "        return result\n",
    "    \n",
    "    def build_trivariate_basis(self, x1: np.ndarray, x2: np.ndarray, \n",
    "                               x3: np.ndarray) -> np.ndarray:\n",
    "        \"\"\"\n",
    "        Build 3-way interaction basis (higher-order).\n",
    "        \n",
    "        Args:\n",
    "            x1, x2, x3: 1D arrays of shape (n_samples,)\n",
    "            \n",
    "        Returns:\n",
    "            3-way interaction basis\n",
    "        \"\"\"\n",
    "        B1 = self.build_univariate_basis(x1)\n",
    "        B2 = self.build_univariate_basis(x2)\n",
    "        B3 = self.build_univariate_basis(x3)\n",
    "        \n",
    "        n_samples = B1.shape[0]\n",
    "        n_basis1, n_basis2, n_basis3 = B1.shape[1], B2.shape[1], B3.shape[1]\n",
    "        \n",
    "        result = np.zeros((n_samples, n_basis1 * n_basis2 * n_basis3))\n",
    "        for i in range(n_samples):\n",
    "            tensor = np.outer(np.outer(B1[i], B2[i]).ravel(), B3[i]).ravel()\n",
    "            result[i] = tensor\n",
    "        \n",
    "        return result\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "c8e3ec73",
   "metadata": {},
   "outputs": [],
   "source": [
    "from skglm import GroupLasso as SkglmGroupLasso\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "@dataclass\n",
    "class InteractionConfig:\n",
    "    \"\"\"Configuration for interaction detection\"\"\"\n",
    "    n_splines: int = 10\n",
    "    spline_degree: int = 3\n",
    "    include_intercept: bool = False\n",
    "    cv_folds: int = 5\n",
    "    max_interaction_order: int = 2\n",
    "    lambda_range: np.ndarray = None\n",
    "    group_reg_range: np.ndarray = None\n",
    "    group_reg: float = 0.1  # For GroupLasso\n",
    "    l1_reg: float = 0.0     # For GroupLasso L1 penalty\n",
    "    random_state: int = 42\n",
    "    verbose: bool = True\n",
    "    model_verbose: bool = False\n",
    "    def __post_init__(self):\n",
    "        if self.lambda_range is None:\n",
    "            self.lambda_range = np.logspace(-4, 1, 15)\n",
    "\n",
    "        if self.group_reg_range is None:\n",
    "            self.group_reg_range = np.logspace(-1, 1, 10)\n",
    "            ############### If case is bike ##############\n",
    "            #self.group_reg_range = np.logspace(0.6, 1, 5)\n",
    "            ##############################################\n",
    "class OptimizedAdditiveInteractionSelector:\n",
    "    \"\"\"\n",
    "    COMPLETE interaction selector supporting ALL 5 methods:\n",
    "    1. group_lasso - sklearn GroupLasso (requires: pip install group-lasso)\n",
    "    2. group_lasso_torch - PyTorch GPU-accelerated GroupLasso â­\n",
    "    3. elastic_net - ElasticNet (default, no external dependencies)\n",
    "    4. lasso - Lasso\n",
    "    5. ridge - Ridge\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, config: Optional[InteractionConfig] = None):\n",
    "        self.config = config or InteractionConfig()\n",
    "        self.spline_basis = OptimizedSplineBasis(\n",
    "            n_splines=self.config.n_splines,\n",
    "            spline_degree=self.config.spline_degree,\n",
    "            include_intercept=self.config.include_intercept\n",
    "        )\n",
    "        \n",
    "        self.groups = []\n",
    "        self.group_names = []\n",
    "        self.group_indices = []\n",
    "        self.scaler = None\n",
    "        self.design_matrix_ = None\n",
    "        self.model = None\n",
    "        self.group_norms_ = None\n",
    "        self.coefficients_ = None\n",
    "        self.interaction_importance_ = None\n",
    "    \n",
    "    def _build_design_matrix(self, X_df: pd.DataFrame, \n",
    "                            interactions: Optional[List[Tuple]] = None) -> np.ndarray:\n",
    "        \"\"\"Build design matrix with groups\"\"\"\n",
    "        blocks = []\n",
    "        col_idx = 0\n",
    "        \n",
    "        # Main effects\n",
    "        for col_name in X_df.columns:\n",
    "            x_col = X_df[col_name].values\n",
    "            B = self.spline_basis.build_univariate_basis(x_col)\n",
    "            \n",
    "            blocks.append(B)\n",
    "            self.groups.append(list(range(col_idx, col_idx + B.shape[1])))\n",
    "            self.group_names.append((col_name,))\n",
    "            self.group_indices.append([X_df.columns.get_loc(col_name)])\n",
    "            \n",
    "            col_idx += B.shape[1]\n",
    "            \n",
    "            if self.config.model_verbose:\n",
    "                print(f\"Main effect: {col_name} ({B.shape[1]} basis)\")\n",
    "        \n",
    "        # Interactions\n",
    "        if interactions:\n",
    "            for int_tuple in interactions:\n",
    "                if len(int_tuple) == 2:\n",
    "                    a, b = int_tuple\n",
    "                    col_a, col_b = X_df.columns[a], X_df.columns[b]\n",
    "                    \n",
    "                    B_int = self.spline_basis.build_bivariate_basis(\n",
    "                        X_df[col_a].values, X_df[col_b].values\n",
    "                    )\n",
    "                    \n",
    "                    blocks.append(B_int)\n",
    "                    self.groups.append(list(range(col_idx, col_idx + B_int.shape[1])))\n",
    "                    self.group_names.append((col_a, col_b))\n",
    "                    self.group_indices.append([a, b])\n",
    "                    \n",
    "                    col_idx += B_int.shape[1]\n",
    "                    \n",
    "                    if self.config.model_verbose:\n",
    "                        print(f\"2-way: {col_a} & {col_b} ({B_int.shape[1]} basis)\")\n",
    "        \n",
    "        self.design_matrix_ = np.hstack(blocks)\n",
    "        \n",
    "        if self.config.verbose:\n",
    "            print(f\"\\nDesign matrix: {self.design_matrix_.shape}\")\n",
    "            print(f\"   Total groups: {len(self.groups)}\")\n",
    "        \n",
    "        return self.design_matrix_\n",
    "    \n",
    "    def fit(self, X: pd.DataFrame, y: np.ndarray, \n",
    "            interactions: Optional[List[Tuple]] = None,\n",
    "            method: str = 'elastic_net') -> 'OptimizedAdditiveInteractionSelector':\n",
    "        \"\"\"\n",
    "        Fit interaction model with specified method\n",
    "        \n",
    "        Args:\n",
    "            X: Input features (DataFrame)\n",
    "            y: Target variable\n",
    "            interactions: List of interaction tuples\n",
    "            method: 'group_lasso', 'group_lasso_torch', 'elastic_net', 'lasso', 'ridge'\n",
    "        \"\"\"\n",
    "        # Build design matrix\n",
    "        Xd = self._build_design_matrix(X, interactions)\n",
    "        \n",
    "        # Standardize\n",
    "        self.scaler = StandardScaler()\n",
    "        Xd_scaled = self.scaler.fit_transform(Xd)\n",
    "        \n",
    "        if self.config.verbose:\n",
    "            print(f\"\\nFitting {method} model...\")\n",
    "        \n",
    "        # Fit based on method\n",
    "        if method == 'group_lasso':\n",
    "            self._fit_group_lasso(Xd_scaled, y)\n",
    "        \n",
    "        elif method == 'elastic_net':\n",
    "            self._fit_elastic_net(Xd_scaled, y)\n",
    "        \n",
    "        elif method == 'lasso':\n",
    "            self._fit_lasso(Xd_scaled, y)\n",
    "        \n",
    "        elif method == 'ridge':\n",
    "            self._fit_ridge(Xd_scaled, y)\n",
    "        \n",
    "        else:\n",
    "            raise ValueError(f\"Unknown method: {method}. Use: group_lasso, group_lasso_torch, elastic_net, lasso, ridge\")\n",
    "        \n",
    "        # Compute group importance\n",
    "        self._compute_group_importance()\n",
    "        \n",
    "        if self.config.verbose:\n",
    "            print(f\"Model fitted successfully\")\n",
    "        \n",
    "        return self\n",
    "    \n",
    "    def _fit_group_lasso(self, X: np.ndarray, y: np.ndarray):\n",
    "\n",
    "        if self.config.verbose:\n",
    "            print(\"Using skglm.GroupLasso\")\n",
    "        \n",
    "        # Convert group list to group array\n",
    "        # skglm expects groups as array where groups[i] = group_id for feature i\n",
    "        group_array = []\n",
    "        for _, group_indices in enumerate(self.groups):\n",
    "            #for idx in group_indices:\n",
    "            group_array.append(len(group_indices))\n",
    "        \n",
    "        best_score, self.model = -np.inf, None\n",
    "        kf = KFold(n_splits=self.config.cv_folds, shuffle=True, random_state=self.config.random_state)\n",
    "\n",
    "\n",
    "        \n",
    "        for lam in self.config.group_reg_range:\n",
    "            print('='*20)\n",
    "            print(f\"Lambda: {lam:.3f}\")\n",
    "            scores = []; cnt = 1\n",
    "            for tr, va in kf.split(X):\n",
    "                print(f\"{cnt}/{self.config.cv_folds} Folds\")\n",
    "                model = SkglmGroupLasso(\n",
    "                    alpha=lam,\n",
    "                    groups=group_array,\n",
    "                    weights=np.ones(len(self.groups)),\n",
    "                    fit_intercept=True,\n",
    "                    tol=1e-4,\n",
    "                    max_iter=10000,\n",
    "                    verbose=self.config.model_verbose\n",
    "                )\n",
    "                model.fit(X[tr], y[tr])\n",
    "                scores.append(model.score(X[va], y[va]))\n",
    "                cnt += 1\n",
    "            if np.mean(scores) > best_score:\n",
    "                best_score = np.mean(scores)\n",
    "                self.model = SkglmGroupLasso(\n",
    "                    alpha=lam,\n",
    "                    groups=group_array,\n",
    "                    weights=np.ones(len(self.groups)),\n",
    "                    fit_intercept=True,\n",
    "                    tol=1e-4,\n",
    "                    max_iter=10000,\n",
    "                    verbose=self.config.model_verbose\n",
    "                )\n",
    "            print('='*20)\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "        '''\n",
    "        self.model = SkglmGroupLasso(\n",
    "            alpha=self.config.group_reg,\n",
    "            groups=group_array,\n",
    "            fit_intercept=True,\n",
    "            tol=1e-4,\n",
    "            max_iter=10000,\n",
    "            verbose=self.config.verbose\n",
    "        )\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "        '''\n",
    "    \n",
    "    def _fit_elastic_net(self, X: np.ndarray, y: np.ndarray):\n",
    "        \"\"\"Fit with ElasticNet\"\"\"\n",
    "        from sklearn.linear_model import ElasticNetCV\n",
    "        if self.config.verbose:\n",
    "            print(\"Using ElasticNetCV\")\n",
    "        \n",
    "        self.model = ElasticNetCV(\n",
    "            cv=self.config.cv_folds,\n",
    "            alphas=self.config.lambda_range,\n",
    "            l1_ratio=[0.1, 0.3, 0.5, 0.7, 0.9],\n",
    "            max_iter=10000\n",
    "        )\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "    \n",
    "    def _fit_lasso(self, X: np.ndarray, y: np.ndarray):\n",
    "        \"\"\"Fit with Lasso\"\"\"\n",
    "        from sklearn.linear_model import LassoCV\n",
    "        if self.config.verbose:\n",
    "            print(\"Using LASSO\")\n",
    "        self.model = LassoCV(cv=self.config.cv_folds, alphas=self.config.lambda_range)\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "    \n",
    "    def _fit_ridge(self, X: np.ndarray, y: np.ndarray):\n",
    "        \"\"\"Fit with Ridge\"\"\"\n",
    "        from sklearn.linear_model import RidgeCV\n",
    "        if self.config.verbose:\n",
    "            print(\"Using Ridge\")\n",
    "        self.model = RidgeCV(alphas=self.config.lambda_range, cv=self.config.cv_folds)\n",
    "        self.model.fit(X, y)\n",
    "        self.coefficients_ = self.model.coef_\n",
    "    \n",
    "    def _compute_group_importance(self):\n",
    "        \"\"\"Compute L2 norm for each group\"\"\"\n",
    "        if hasattr(self.model, 'get_group_norms'):\n",
    "            # For TorchGroupLasso\n",
    "            self.group_norms_ = self.model.model.get_group_norms()\n",
    "        else:\n",
    "            # For sklearn methods\n",
    "            self.group_norms_ = []\n",
    "            for group_indices in self.groups:\n",
    "                group_coef = self.coefficients_[group_indices]\n",
    "                norm = np.linalg.norm(group_coef, ord=2)\n",
    "                self.group_norms_.append(norm)\n",
    "            self.group_norms_ = np.array(self.group_norms_)\n",
    "        \n",
    "        self.interaction_importance_ = pd.DataFrame({\n",
    "            'group': self.group_names,\n",
    "            'feature_indices': self.group_indices,\n",
    "            'importance': self.group_norms_\n",
    "        }).sort_values('importance', ascending=False).reset_index(drop=True)\n",
    "    \n",
    "    def get_interaction_ranking(self) -> pd.DataFrame:\n",
    "        \"\"\"Return ranked interactions by importance\"\"\"\n",
    "        return self.interaction_importance_.copy()\n",
    "    \n",
    "    def get_important_interactions(self, threshold: float = None,\n",
    "                                  top_k: Optional[int] = None) -> List[Tuple]:\n",
    "        \"\"\"Get important interactions\"\"\"\n",
    "        if self.interaction_importance_ is None:\n",
    "            raise ValueError(\"Model not fitted. Call fit() first.\")\n",
    "        \n",
    "        df = self.interaction_importance_.copy()\n",
    "        \n",
    "        if top_k is not None:\n",
    "            df = df.head(top_k)\n",
    "        elif threshold is not None:\n",
    "            df = df[df['importance'] > threshold]\n",
    "        \n",
    "        return [tuple(int(i) for i in row) for row in df['feature_indices']]\n",
    "    \n",
    "    def summary(self, top_n: int = 10):\n",
    "        \"\"\"Print summary\"\"\"\n",
    "        print(\"\\n\" + \"=\"*70)\n",
    "        print(\"INTERACTION IMPORTANCE RANKING\")\n",
    "        print(\"=\"*70)\n",
    "        \n",
    "        df = self.interaction_importance_.head(top_n)\n",
    "        \n",
    "        for idx, row in df.iterrows():\n",
    "            group_name = \" & \".join(row['group']) if isinstance(row['group'], tuple) else str(row['group'])\n",
    "            print(f\"{idx+1:2d}. {group_name:30s} | Importance: {row['importance']:.6f}\")\n",
    "        \n",
    "        print(\"=\"*70)\n",
    "\n",
    "# \n",
    "# Section 4: TorchInteractionDetector (Neural Alternative)\n",
    "# \n",
    "\n",
    "class TorchInteractionDetector(nn.Module):\n",
    "    \"\"\"Neural network for interaction detection with attention\"\"\"\n",
    "    \n",
    "    def __init__(self, input_dim: int, hidden_dims: List[int] = [64, 32],\n",
    "                 n_interactions: int = None, dropout: float = 0.2):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.input_dim = input_dim\n",
    "        self.n_interactions = n_interactions or (input_dim * 2)\n",
    "        \n",
    "        # Trunk\n",
    "        trunk_layers = []\n",
    "        in_dim = input_dim\n",
    "        \n",
    "        for h_dim in hidden_dims:\n",
    "            trunk_layers.extend([\n",
    "                nn.Linear(in_dim, h_dim),\n",
    "                nn.BatchNorm1d(h_dim),\n",
    "                nn.GELU(),\n",
    "                nn.Dropout(dropout)\n",
    "            ])\n",
    "            in_dim = h_dim\n",
    "        \n",
    "        self.trunk = nn.Sequential(*trunk_layers)\n",
    "        \n",
    "        # Prediction head\n",
    "        self.interaction_head = nn.Sequential(\n",
    "            nn.Linear(in_dim, self.n_interactions),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(self.n_interactions, 1)\n",
    "        )\n",
    "        \n",
    "        # Attention mechanism\n",
    "        self.attention = nn.Sequential(\n",
    "            nn.Linear(in_dim, self.n_interactions),\n",
    "            nn.Softmax(dim=1)\n",
    "        )\n",
    "    \n",
    "    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        trunk_out = self.trunk(x)\n",
    "        predictions = self.interaction_head(trunk_out)\n",
    "        weights = self.attention(trunk_out)\n",
    "        return predictions, weights\n",
    "\n",
    "\n",
    "class TorchInteractionDetectorTrainer:\n",
    "    \"\"\"Trainer for TorchInteractionDetector\"\"\"\n",
    "    \n",
    "    def __init__(self, model: TorchInteractionDetector,\n",
    "                 learning_rate: float = 1e-3, sparsity_weight: float = 0.01):\n",
    "        self.model = model.to(device)\n",
    "        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
    "        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
    "            self.optimizer, T_max=1000, eta_min=1e-6\n",
    "        )\n",
    "        self.criterion = nn.MSELoss()\n",
    "        self.sparsity_weight = sparsity_weight\n",
    "        self.history = {'train_loss': [], 'val_loss': []}\n",
    "    \n",
    "    def fit(self, X_train: torch.Tensor, y_train: torch.Tensor,\n",
    "            X_val: torch.Tensor, y_val: torch.Tensor,\n",
    "            n_epochs: int = 1000, batch_size: int = 128):\n",
    "        \"\"\"Train model\"\"\"\n",
    "        train_dataset = TensorDataset(X_train, y_train)\n",
    "        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "        \n",
    "        val_dataset = TensorDataset(X_val, y_val)\n",
    "        val_loader = DataLoader(val_dataset, batch_size=batch_size*2, shuffle=False)\n",
    "        \n",
    "        best_val_loss = float('inf')\n",
    "        patience = 50\n",
    "        patience_counter = 0\n",
    "        \n",
    "        pbar = tqdm(range(n_epochs), desc='Training')\n",
    "        \n",
    "        for epoch in pbar:\n",
    "            # Train\n",
    "            self.model.train()\n",
    "            train_loss = 0.0\n",
    "            \n",
    "            for X_batch, y_batch in train_loader:\n",
    "                X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
    "                self.optimizer.zero_grad()\n",
    "                \n",
    "                y_pred, weights = self.model(X_batch)\n",
    "                mse_loss = self.criterion(y_pred, y_batch)\n",
    "                sparsity_loss = weights.abs().mean()\n",
    "                loss = mse_loss + self.sparsity_weight * sparsity_loss\n",
    "                \n",
    "                loss.backward()\n",
    "                self.optimizer.step()\n",
    "                train_loss += mse_loss.item()\n",
    "            \n",
    "            train_loss /= len(train_loader)\n",
    "            \n",
    "            # Validate\n",
    "            self.model.eval()\n",
    "            val_loss = 0.0\n",
    "            with torch.no_grad():\n",
    "                for X_batch, y_batch in val_loader:\n",
    "                    X_batch, y_batch = X_batch.to(device), y_batch.to(device)\n",
    "                    y_pred, _ = self.model(X_batch)\n",
    "                    loss = self.criterion(y_pred, y_batch)\n",
    "                    val_loss += loss.item()\n",
    "            \n",
    "            val_loss /= len(val_loader)\n",
    "            \n",
    "            self.scheduler.step()\n",
    "            self.history['train_loss'].append(train_loss)\n",
    "            self.history['val_loss'].append(val_loss)\n",
    "            \n",
    "            pbar.set_postfix({'train': f'{train_loss:.6f}', 'val': f'{val_loss:.6f}'})\n",
    "            \n",
    "            # Early stopping\n",
    "            if val_loss < best_val_loss:\n",
    "                best_val_loss = val_loss\n",
    "                patience_counter = 0\n",
    "            else:\n",
    "                patience_counter += 1\n",
    "                if patience_counter >= patience:\n",
    "                    print(f\"\\nEarly stopping at epoch {epoch+1}\")\n",
    "                    break\n",
    "        \n",
    "        return self.history\n",
    "    \n",
    "    def get_interaction_importance(self, X: torch.Tensor) -> np.ndarray:\n",
    "        \"\"\"Get average attention weights\"\"\"\n",
    "        self.model.eval()\n",
    "        with torch.no_grad():\n",
    "            X = X.to(device)\n",
    "            _, weights = self.model(X)\n",
    "            importance = weights.mean(dim=0).cpu().numpy()\n",
    "        return importance\n",
    "    \n",
    "config = InteractionConfig(n_splines=8, group_reg=0.1)\n",
    "selector = OptimizedAdditiveInteractionSelector(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57ca90f8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "552f599e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(\"=\"*70)\n",
    "print(\"Optimized Higher-Order Interaction Detection\")\n",
    "print(\"=\"*70 + \"\\n\")\n",
    "\n",
    "# Generate sample data\n",
    "\n",
    "\n",
    "X_train, interaction_pairs = extract_active_features(np.array(X_train), results_['opt_var'])\n",
    "y_train = np.array(y_train)\n",
    "\n",
    "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)\n",
    "\n",
    "X_test, interaction_pairs = extract_active_features(np.array(X_test), results_['opt_var'])\n",
    "y_test = np.array(y_test)\n",
    "\n",
    "# 1. Fit main effects + 2-way interactions\n",
    "config = InteractionConfig(\n",
    "    n_splines=3,\n",
    "    cv_folds=5,\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "selector = OptimizedAdditiveInteractionSelector(config)\n",
    "selector.fit(X_train, y_train, interactions=interaction_pairs, method='group_lasso')\n",
    "selector.summary(top_n=6)\n",
    "\n",
    "important = selector.get_important_interactions(top_k=3)\n",
    "print(f\"\\nTop 3 important interactions: {important}\")\n",
    "\n",
    "# 3. Ranking table\n",
    "print(\"\\nFull Ranking:\")\n",
    "print(selector.get_interaction_ranking())\n",
    "\n",
    "ranking_df = selector.get_interaction_ranking()\n",
    "interaction = [i for i in ranking_df[ranking_df['importance']>0]['feature_indices']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a78e5caf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "60453e54",
   "metadata": {},
   "source": [
    "### DNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "578b7931-8840-48c9-bd5f-18c412d845d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "Train = pd.read_csv(path+case+'_Train.csv')\n",
    "Valid = pd.read_csv(path+case+'_Valid.csv')\n",
    "Test = pd.read_csv(path+case+'_Test.csv')\n",
    "\n",
    "y_train = Train['y']\n",
    "X_train = Train.iloc[:, :-1]\n",
    "y_val = Valid['y']\n",
    "X_val = Valid.iloc[:, :-1]\n",
    "y_test = Test['y']\n",
    "X_test = Test.iloc[:, :-1]\n",
    "\n",
    "X_train = torch.tensor(X_train.to_numpy(), dtype=torch.float32)\n",
    "y_train = torch.tensor(y_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val.to_numpy(), dtype=torch.float32)\n",
    "y_val = torch.tensor(y_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test.to_numpy(), dtype=torch.float32)\n",
    "y_test = torch.tensor(y_test, dtype=torch.float32)\n",
    "\n",
    "scaler = StandardScaler()\n",
    "\n",
    "X_train = scaler.fit_transform(X_train.cpu().numpy())\n",
    "X_val = scaler.transform(X_val.cpu().numpy())\n",
    "X_test = scaler.transform(X_test.cpu().numpy())\n",
    "\n",
    "# 4. 轉回 Tensor\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7df9e2f7-4bad-4f49-8f30-7c8c017aabb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DNNBaseline(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dims=[64, 32]):\n",
    "        super(DNNBaseline, self).__init__()\n",
    "        layers = []\n",
    "        \n",
    "        for h_dim in hidden_dims:\n",
    "            layers.extend([\n",
    "                nn.Linear(input_dim, h_dim),\n",
    "                nn.ReLU()\n",
    "            ])\n",
    "            input_dim = h_dim\n",
    "            \n",
    "        # Final output layer\n",
    "        layers.append(nn.Linear(input_dim, 1))\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "    \n",
    "class EarlyStopper:\n",
    "    def __init__(self, patience=1, min_delta=0):\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.counter = 0\n",
    "        self.min_validation_loss = float('inf')\n",
    "\n",
    "    def early_stop(self, validation_loss):\n",
    "        if validation_loss < self.min_validation_loss:\n",
    "            self.min_validation_loss = validation_loss\n",
    "            self.counter = 0\n",
    "        elif validation_loss > (self.min_validation_loss + self.min_delta):\n",
    "            self.counter += 1\n",
    "            if self.counter >= self.patience:\n",
    "                return True\n",
    "        return False\n",
    "    \n",
    "def train(model, X_train, y_train, X_val, y_val, file_path, n_epochs = 10000, batch_size=64, lr=1e-2, pt = 50):\n",
    "    loss_fn = nn.MSELoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr = lr)\n",
    "    scheduler = ReduceLROnPlateau(optimizer, 'min')\n",
    "    early_stopping = EarlyStopper(patience = pt)\n",
    "\n",
    "    # Training loop\n",
    "    n_epochs = n_epochs\n",
    "    batch_size = X_train.size()[0]\n",
    "    batch_start = torch.arange(0, len(X_train), batch_size)\n",
    "\n",
    "    best_mse = float('inf')\n",
    "    best_weights = None\n",
    "\n",
    "    start_time = time.time()\n",
    "    patient = 0\n",
    "    # Training loop\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        \n",
    "        for start in batch_start:\n",
    "            X_batch = X_train[start:start+batch_size]\n",
    "            y_batch = y_train[start:start+batch_size]\n",
    "            \n",
    "            # Forward pass\n",
    "            y_pred = model(X_batch)\n",
    "            loss = loss_fn(y_pred, y_batch)\n",
    "            \n",
    "            # Backward pass\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        \n",
    "        # Evaluate model on test set at the end of each epoch\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            y_pred = model(X_val)\n",
    "            val_loss = loss_fn(y_pred, y_val)\n",
    "            val_loss = float(val_loss)\n",
    "            scheduler.step(val_loss)\n",
    "            \n",
    "            if early_stopping.early_stop(val_loss):\n",
    "                print(f\"Early Stop at Epoch {epoch}\")\n",
    "                break\n",
    "            \n",
    "            if val_loss < best_mse:\n",
    "                best_mse = val_loss\n",
    "                best_weights = copy.deepcopy(model.state_dict())\n",
    "            \n",
    "        '''  \n",
    "        if patient == pt:\n",
    "            print(f\"Early Stop at Epoch {epoch}\")\n",
    "            break\n",
    "        '''\n",
    "        \n",
    "        if (epoch+1) % 1000 == 0:\n",
    "            print(f\"Epoch {epoch+1}, MSE: {val_loss}\")\n",
    "\n",
    "    end_time = time.time()\n",
    "    Time_consumption = end_time - start_time\n",
    "    torch.save(best_weights, file_path)\n",
    "\n",
    "    return Time_consumption\n",
    "\n",
    "def eval_model(model, path, X, y):\n",
    "    model.load_state_dict(torch.load(path, weights_only=True))\n",
    "    model.eval()\n",
    "\n",
    "    loss_fn = nn.MSELoss()\n",
    "    with torch.no_grad():\n",
    "        y_pred = model(X)\n",
    "        mse = np.sqrt(loss_fn(y, y_pred))\n",
    "\n",
    "    return mse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d24ecf3b-f436-47c5-a21e-1eba9d7a8899",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = '../../Dataset/Real-Data-Application/'\n",
    "\n",
    "Train = pd.read_csv(path+case+'_Train.csv')\n",
    "Valid = pd.read_csv(path+case+'_Valid.csv')\n",
    "Test = pd.read_csv(path+case+'_Test.csv')\n",
    "\n",
    "\n",
    "y_train = Train['y']\n",
    "X_train = Train.iloc[:, :-1]\n",
    "y_val = Valid['y']\n",
    "X_val = Valid.iloc[:, :-1]\n",
    "y_test = Test['y']\n",
    "X_test = Test.iloc[:, :-1]\n",
    "\n",
    "X_train = torch.tensor(X_train.to_numpy(), dtype=torch.float32)\n",
    "y_train = torch.tensor(y_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val.to_numpy(), dtype=torch.float32)\n",
    "y_val = torch.tensor(y_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test.to_numpy(), dtype=torch.float32)\n",
    "y_test = torch.tensor(y_test, dtype=torch.float32)\n",
    "\n",
    "scaler = StandardScaler()\n",
    "\n",
    "X_train = scaler.fit_transform(X_train.cpu().numpy())\n",
    "X_val = scaler.transform(X_val.cpu().numpy())\n",
    "X_test = scaler.transform(X_test.cpu().numpy())\n",
    "\n",
    "# 4. 轉回 Tensor\n",
    "X_train = torch.tensor(X_train, dtype=torch.float32)\n",
    "X_val = torch.tensor(X_val, dtype=torch.float32)\n",
    "X_test = torch.tensor(X_test, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "8ed5913f-0a86-47eb-b568-2a8cb1f294be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Early Stop at Epoch 517\n"
     ]
    }
   ],
   "source": [
    "DNN = DNNBaseline(input_dim=X_train.size(1), hidden_dims = [128, 64, 32])\n",
    "Runtime_dict = train(DNN, X_train, y_train.view(-1, 1), X_val, y_val.view(-1, 1), 'DNN_'+case+'.pth', n_epochs = 20000, batch_size=1024, lr = 5e-3)\n",
    "MSPE_dict= eval_model(DNN, 'DNN_'+case+'.pth', X_test, y_test.view(-1, 1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "c30c3bce-7174-4ae3-86ca-c29eeebb8a26",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "58.84652137756348\n",
      "tensor(0.7002)\n"
     ]
    }
   ],
   "source": [
    "print(Runtime_dict)\n",
    "print(MSPE_dict)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
