{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "635a1ed9",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '/Users/a080528/Downloads/results/only_main-n150.json'",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mFileNotFoundError\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 10\u001b[39m\n\u001b[32m      8\u001b[39m \u001b[38;5;66;03m# Open the JSON file in read mode ('r')\u001b[39;00m\n\u001b[32m      9\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m file_type:\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfile_path\u001b[49m\u001b[43m+\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m+\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m.json\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mr\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[32m     11\u001b[39m         \u001b[38;5;66;03m# Use json.load() to parse the JSON data from the file object\u001b[39;00m\n\u001b[32m     12\u001b[39m         data = json.load(f)\n\u001b[32m     14\u001b[39m     \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mData from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.json:\u001b[39m\u001b[33m\"\u001b[39m)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/IPython/core/interactiveshell.py:326\u001b[39m, in \u001b[36m_modified_open\u001b[39m\u001b[34m(file, *args, **kwargs)\u001b[39m\n\u001b[32m    319\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[32m0\u001b[39m, \u001b[32m1\u001b[39m, \u001b[32m2\u001b[39m}:\n\u001b[32m    320\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m    321\u001b[39m         \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mIPython won\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m by default \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    322\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    323\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33myou can use builtins\u001b[39m\u001b[33m'\u001b[39m\u001b[33m open.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    324\u001b[39m     )\n\u001b[32m--> \u001b[39m\u001b[32m326\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[31mFileNotFoundError\u001b[39m: [Errno 2] No such file or directory: '/Users/a080528/Downloads/results/only_main-n150.json'"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "# Define the path to your JSON file\n",
    "file_path = '/Users/a080528/Downloads/results/'  # Replace with the actual path to your file\n",
    "file_type = ['only_main-n150', 'weak_main-n150', 'inter_no_overlap-n150',\n",
    "              'inter_mild_overlap-n150', 'inter_strong_overlap-n150', 'only_inter-n150']\n",
    "\n",
    "# Open the JSON file in read mode ('r')\n",
    "for file in file_type:\n",
    "    with open(file_path+file+'.json', 'r') as f:\n",
    "        # Use json.load() to parse the JSON data from the file object\n",
    "        data = json.load(f)\n",
    "\n",
    "    print(f\"Data from {file}.json:\")\n",
    "    print(f\"MSE Mean (DNN): {np.mean(data['MSE_DNN']):.2f}, Std: {np.std(data['MSE_DNN']):.2f}\")\n",
    "    print(f\"MSE Mean (SDAM): {np.mean(data['MSE_ADB']):.2f}, Std: {np.std(data['MSE_ADB']):.2f}\")\n",
    "    #print(f\"Runtime Mean (DNN): {np.mean(data['Runtime_DNN'])}, Std: {np.std(data['Runtime_DNN'])}\")\n",
    "    #print(f\"Runtime Mean (SDAM): {np.mean(data['Runtime_ADN'])}, Std: {np.std(data['Runtime_ADN'])}\")\n",
    "\n",
    "    print('-----------------------------------------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "36655202",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '/Users/a080528/Downloads/results/sdam-only_main-n300.json'",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mFileNotFoundError\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 11\u001b[39m\n\u001b[32m      9\u001b[39m \u001b[38;5;66;03m# Open the JSON file in read mode ('r')\u001b[39;00m\n\u001b[32m     10\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m file_type:\n\u001b[32m---> \u001b[39m\u001b[32m11\u001b[39m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfile_path\u001b[49m\u001b[43m+\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m+\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m.json\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mr\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[32m     12\u001b[39m         \u001b[38;5;66;03m# Use json.load() to parse the JSON data from the file object\u001b[39;00m\n\u001b[32m     13\u001b[39m         data = json.load(f)\n\u001b[32m     15\u001b[39m     \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mData from \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.json:\u001b[39m\u001b[33m\"\u001b[39m)\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/IPython/core/interactiveshell.py:326\u001b[39m, in \u001b[36m_modified_open\u001b[39m\u001b[34m(file, *args, **kwargs)\u001b[39m\n\u001b[32m    319\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[32m0\u001b[39m, \u001b[32m1\u001b[39m, \u001b[32m2\u001b[39m}:\n\u001b[32m    320\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m    321\u001b[39m         \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mIPython won\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m by default \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    322\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    323\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33myou can use builtins\u001b[39m\u001b[33m'\u001b[39m\u001b[33m open.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    324\u001b[39m     )\n\u001b[32m--> \u001b[39m\u001b[32m326\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[31mFileNotFoundError\u001b[39m: [Errno 2] No such file or directory: '/Users/a080528/Downloads/results/sdam-only_main-n300.json'"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "# Define the path to your JSON file\n",
    "file_path = '/Users/a080528/Downloads/results/sdam-'  # Replace with the actual path to your file\n",
    "file_type = ['only_main-n300', 'weak_main-n300', 'inter_no_overlap-n300',\n",
    "              'inter_mild_overlap-n300', 'inter_strong_overlap-n300', 'only_inter-n300', \n",
    "              'weak_main-n450', 'inter_strong_overlap-n450', 'only_inter-n450']\n",
    "\n",
    "# Open the JSON file in read mode ('r')\n",
    "for file in file_type:\n",
    "    with open(file_path+file+'.json', 'r') as f:\n",
    "        # Use json.load() to parse the JSON data from the file object\n",
    "        data = json.load(f)\n",
    "\n",
    "    print(f\"Data from {file}.json:\")\n",
    "    print(f\"MSE Mean (DNN): {np.mean(data['MSE_DNN']):.2f}, Std: {np.std(data['MSE_DNN']):.2f}\")\n",
    "    print(f\"MSE Mean (SDAM): {np.mean(data['MSE_ADB']):.2f}, Std: {np.std(data['MSE_ADB']):.2f}\")\n",
    "    #print(f\"Runtime Mean (DNN): {np.mean(data['Runtime_DNN'])}, Std: {np.std(data['Runtime_DNN'])}\")\n",
    "    #print(f\"Runtime Mean (SDAM): {np.mean(data['Runtime_ADN'])}, Std: {np.std(data['Runtime_ADN'])}\")\n",
    "\n",
    "    print('-----------------------------------------------------------')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e5fb26b",
   "metadata": {},
   "source": [
    "## Comparison"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "47c31127",
   "metadata": {},
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'fastsparsegams'",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mModuleNotFoundError\u001b[39m                       Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 10\u001b[39m\n\u001b[32m      8\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mpygam\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m LinearGAM, s, te\n\u001b[32m      9\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01msklearn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mmetrics\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m mean_squared_error\n\u001b[32m---> \u001b[39m\u001b[32m10\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mfastsparsegams\u001b[39;00m\n\u001b[32m     12\u001b[39m \u001b[38;5;28;01mclass\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mFeatureNet\u001b[39;00m(nn.Module):\n\u001b[32m     13\u001b[39m     \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_dim, hidden_dims):\n",
      "\u001b[31mModuleNotFoundError\u001b[39m: No module named 'fastsparsegams'"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "import numpy as np\n",
    "import time\n",
    "import copy\n",
    "import numpy as np\n",
    "from pygam import LinearGAM, s, te\n",
    "from sklearn.metrics import mean_squared_error\n",
    "import fastsparsegams\n",
    "\n",
    "class FeatureNet(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dims):\n",
    "        \"\"\"\n",
    "        Flexible feature/interaction-specific network\n",
    "        Args:\n",
    "            input_dim: number of input dimensions for this subnetwork\n",
    "            hidden_dims: list of hidden layer dimensions (e.g., [6, 3] for input->6->3->1)\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        layers = []\n",
    "        in_dim = input_dim\n",
    "        \n",
    "        # Hidden layers\n",
    "        for h_dim in hidden_dims:\n",
    "            layers.extend([\n",
    "                nn.Linear(in_dim, h_dim),\n",
    "                nn.ReLU()\n",
    "            ])\n",
    "            in_dim = h_dim\n",
    "            \n",
    "        # Final output layer (scalar output)\n",
    "        layers.append(nn.Linear(in_dim, 1))\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "\n",
    "class AdditiveModel(nn.Module):\n",
    "    def __init__(self, index_list, hidden_dims, output_dim=1):\n",
    "        \"\"\"\n",
    "        Additive model with main effects + interaction terms\n",
    "        Args:\n",
    "            index_list: list of lists specifying feature groups, \n",
    "                        e.g. [[0], [1], [2], [2,3]]\n",
    "            hidden_dims: hidden layer sizes for each subnetwork\n",
    "            output_dim: dimension of model output\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.index_list = index_list\n",
    "        \n",
    "        # Build one FeatureNet per group in index_list\n",
    "        self.feature_nets = nn.ModuleList([\n",
    "            FeatureNet(len(indices), hidden_dims) \n",
    "            for indices in index_list\n",
    "        ])\n",
    "        \n",
    "        # Linear combiner (without bias) to sum contributions\n",
    "        self.combiner = nn.Linear(len(index_list), output_dim, bias=False)\n",
    "        self.hook = {}\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"\n",
    "        Forward pass through all subnetworks\n",
    "        Args:\n",
    "            X: input tensor of shape [batch_size, num_features]\n",
    "        \"\"\"\n",
    "        individual_outputs = []\n",
    "        \n",
    "        for indices, net in zip(self.index_list, self.feature_nets):\n",
    "            # Select relevant columns (keep 2D)\n",
    "            x_sub = X[:, indices]\n",
    "            out = net(x_sub)  # [batch_size, 1]\n",
    "            individual_outputs.append(out)\n",
    "        \n",
    "        combined = torch.cat(individual_outputs, dim=1)  # [batch_size, num_subnets]\n",
    "        self.hook['acomp'] = combined\n",
    "        return self.combiner(combined)\n",
    "\n",
    "class EarlyStopper:\n",
    "    def __init__(self, patience=1, min_delta=0):\n",
    "        self.patience = patience\n",
    "        self.min_delta = min_delta\n",
    "        self.counter = 0\n",
    "        self.min_validation_loss = float('inf')\n",
    "\n",
    "    def early_stop(self, validation_loss):\n",
    "        if validation_loss < self.min_validation_loss:\n",
    "            self.min_validation_loss = validation_loss\n",
    "            self.counter = 0\n",
    "        elif validation_loss > (self.min_validation_loss + self.min_delta):\n",
    "            self.counter += 1\n",
    "            if self.counter >= self.patience:\n",
    "                return True\n",
    "        return False\n",
    "    \n",
    "def train(model, X_train, y_train, X_val, y_val, file_path, n_epochs = 10000, batch_size=64, lr=1e-2, pt = 50):\n",
    "    loss_fn = nn.MSELoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr = lr)\n",
    "    scheduler = ReduceLROnPlateau(optimizer, 'min')\n",
    "    early_stopping = EarlyStopper(patience = pt)\n",
    "\n",
    "    # Training loop\n",
    "    n_epochs = n_epochs\n",
    "    batch_size = X_train.size()[0]\n",
    "    batch_start = torch.arange(0, len(X_train), batch_size)\n",
    "\n",
    "    best_mse = float('inf')\n",
    "    best_weights = None\n",
    "\n",
    "    start_time = time.time()\n",
    "    patient = 0\n",
    "    # Training loop\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        \n",
    "        for start in batch_start:\n",
    "            X_batch = X_train[start:start+batch_size]\n",
    "            y_batch = y_train[start:start+batch_size]\n",
    "            \n",
    "            # Forward pass\n",
    "            y_pred = model(X_batch)\n",
    "            loss = loss_fn(y_pred, y_batch)\n",
    "            \n",
    "            # Backward pass\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "        \n",
    "        # Evaluate model on test set at the end of each epoch\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            y_pred = model(X_val)\n",
    "            val_loss = loss_fn(y_pred, y_val)\n",
    "            val_loss = float(val_loss)\n",
    "            scheduler.step(val_loss)\n",
    "            \n",
    "            if early_stopping.early_stop(val_loss):\n",
    "                print(f\"Early Stop at Epoch {epoch}\")\n",
    "                break\n",
    "            \n",
    "            if val_loss < best_mse:\n",
    "                best_mse = val_loss\n",
    "                best_weights = copy.deepcopy(model.state_dict())\n",
    "            \n",
    "        '''  \n",
    "        if patient == pt:\n",
    "            print(f\"Early Stop at Epoch {epoch}\")\n",
    "            break\n",
    "        '''\n",
    "        \n",
    "        if (epoch+1) % 1000 == 0:\n",
    "            print(f\"Epoch {epoch+1}, MSE: {val_loss}\")\n",
    "\n",
    "    end_time = time.time()\n",
    "    Time_consumption = end_time - start_time\n",
    "    torch.save(best_weights, file_path)\n",
    "\n",
    "    return Time_consumption\n",
    "\n",
    "def eval_model(model, path, X, y):\n",
    "    model.load_state_dict(torch.load(path, weights_only=True))\n",
    "    model.eval()\n",
    "\n",
    "    loss_fn = nn.MSELoss()\n",
    "    with torch.no_grad():\n",
    "        y_pred = model(X)\n",
    "        mse = loss_fn(y, y_pred)\n",
    "\n",
    "    return mse"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a1e3e35",
   "metadata": {},
   "source": [
    "### SDAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2fd29da0",
   "metadata": {},
   "outputs": [],
   "source": [
    "## saving data to R for sodavis\n",
    "\n",
    "name = ['only_main_data', 'weak_main_data', 'inter_no_overlap_data', 'inter_mild_overlap_data', 'inter_strong_overlap_data', 'only_inter_data']\n",
    "\n",
    "for n in name:\n",
    "    _dict = torch.load('/Users/a080528/Downloads/data/'+ n + '.pt', weights_only= True)\n",
    "    X_train = np.array(_dict['X_train'])\n",
    "    y_train = np.array(_dict['y_train'])\n",
    "    X_val = np.array(_dict['X_valid'])\n",
    "    y_val = np.array(_dict['y_valid'])\n",
    "    X_test = np.array(_dict['X_test'])\n",
    "    y_test = np.array(_dict['y_test'])\n",
    "    break\n",
    "\n",
    "##saving data to R for sodavis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "294de0af",
   "metadata": {},
   "outputs": [],
   "source": [
    "MSE_result = torch.zeros((100))\n",
    "for r in range(100):\n",
    "    active_idx = [[0], [1], [2], [3]]\n",
    "    ADNN = AdditiveModel(\n",
    "                index_list= active_idx,\n",
    "                hidden_dims= [5, 3],\n",
    "                output_dim=1\n",
    "                )\n",
    "\n",
    "    _ = train(ADNN, torch.tensor(X_train[r][:, [0, 1, 2, 3]]), torch.tensor(y_train[r]).view(-1, 1), torch.tensor(X_val[r][:, [0, 1, 2, 3]]), torch.tensor(y_val[r]).view(-1, 1), './tests/ADM'+str(r+1)+'.pth', n_epochs = 5000, batch_size=32, lr = 1e-3)\n",
    "    MSE_result[r] = eval_model(ADNN, './tests/ADM'+str(r+1)+'.pth', torch.tensor(X_test[r][:, [0, 1, 2, 3]]), torch.tensor(y_test[r]).view(-1, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25e12f8d",
   "metadata": {},
   "source": [
    "### fSpAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb3d263a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "only_main300_data 5.570812944224448 0.30685465876180923\n",
      "weak_main300_data 3.043623159065463 0.15660049508192056\n",
      "inter_no_overlap300_data 3.3718659235906836 0.15074308410917836\n",
      "inter_mild_overlap300_data 3.3192160491782734 0.17064745963754333\n",
      "inter_strong_overlap300_data 3.4476572753770824 0.1579674414623089\n",
      "only_inter300_data 0.6019879660073071 0.03111568000603531\n"
     ]
    }
   ],
   "source": [
    "name = ['only_main_data', 'weak_main_data', 'inter_no_overlap_data', 'inter_mild_overlap_data', 'inter_strong_overlap_data', 'only_inter_data']\n",
    "name = ['only_main300_data', 'weak_main300_data', 'inter_no_overlap300_data', 'inter_mild_overlap300_data', 'inter_strong_overlap300_data', 'only_inter300_data']\n",
    "\n",
    "for n in name:\n",
    "    _dict = torch.load('/Users/a080528/Downloads/data/'+ n + '.pt', weights_only= True)\n",
    "    X_train = np.array(_dict['X_train'])\n",
    "    y_train = np.array(_dict['y_train'])\n",
    "    X_test = np.array(_dict['X_test'])\n",
    "    y_test = np.array(_dict['y_test'])\n",
    "\n",
    "    mse = []\n",
    "    for r in range(X_train.shape[0]):\n",
    "\n",
    "        fit_model = fastsparsegams.fit(X_train[r].astype(np.float64), y_train[r].astype(np.float64), penalty=\"L0\", max_support_size=20)\n",
    "        y_pred = (fit_model.predict(x=X_test[r], lambda_0=0.032715, gamma=0))\n",
    "\n",
    "        mse.append(mean_squared_error(y_pred, y_test[r]))\n",
    "\n",
    "    print(n, np.mean(mse), np.std(mse))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d187972",
   "metadata": {},
   "source": [
    "### LASSO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "3433cb9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "only_main600_data\n",
      "3.044720260302226\n",
      "0.16528252915237834\n",
      "weak_main600_data\n",
      "2.4027133226394652\n",
      "0.10743447968762997\n",
      "inter_no_overlap600_data\n",
      "2.7161594470342\n",
      "0.14409619129407492\n",
      "inter_mild_overlap600_data\n",
      "2.614427367846171\n",
      "0.12988584150805557\n",
      "inter_strong_overlap600_data\n",
      "2.7607612133026125\n",
      "0.1258967548954408\n",
      "only_inter600_data\n",
      "0.3927682211001714\n",
      "0.02169689495841283\n"
     ]
    }
   ],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
    "from sklearn.linear_model import LassoCV\n",
    "\n",
    "r = 30\n",
    "lasso_loss = np.zeros((r))\n",
    "name = ['only_main600_data', 'weak_main600_data', 'inter_no_overlap600_data', 'inter_mild_overlap600_data'\n",
    "        , 'inter_strong_overlap600_data', 'only_inter600_data']\n",
    "\n",
    "for n in name:\n",
    "    _dict = torch.load('../data/' + n + '.pt', weights_only= True)\n",
    "    X_train = np.array(_dict['X_train'])\n",
    "    y_train = np.array(_dict['y_train'])\n",
    "    X_test = np.array(_dict['X_test'])\n",
    "    y_test = np.array(_dict['y_test'])\n",
    "\n",
    "    for i in range(r):\n",
    "        poly = PolynomialFeatures(degree=2, include_bias=False)\n",
    "        X_basis = poly.fit_transform(X_train[i][:450, :])\n",
    "        X_basis_test = poly.fit_transform(X_test[i])\n",
    "\n",
    "        lasso = LassoCV(cv=5)\n",
    "        lasso.fit(X_basis, y_train[i][:450])\n",
    "        test = lasso.predict(X_basis_test)\n",
    "        lasso_loss[i] = mean_squared_error(y_test[i], test)\n",
    "\n",
    "    print(n)\n",
    "    print(lasso_loss.mean())\n",
    "    print(lasso_loss.std())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8840c798",
   "metadata": {},
   "source": [
    "## TPR/ FPR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "641e30d2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------------\n",
      "Main Effect Only\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '/Users/a080528/Downloads/only_main.txt'",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mFileNotFoundError\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 50\u001b[39m\n\u001b[32m     48\u001b[39m true_i = []\n\u001b[32m     49\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m'\u001b[39m\u001b[33mMain Effect Only\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m50\u001b[39m \u001b[43msummary\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrue_m\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrue_i\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_feature\u001b[49m\u001b[43m)\u001b[49m \n\u001b[32m     51\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m'\u001b[39m\u001b[33m-------------------------------\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m     53\u001b[39m file_type = \u001b[33m'\u001b[39m\u001b[33m/Users/a080528/Downloads/weak_main.txt\u001b[39m\u001b[33m'\u001b[39m\n",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 22\u001b[39m, in \u001b[36msummary\u001b[39m\u001b[34m(file_path, true_main, true_inter, p)\u001b[39m\n\u001b[32m     18\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34msummary\u001b[39m(file_path, true_main, true_inter, p):\n\u001b[32m---> \u001b[39m\u001b[32m22\u001b[39m     \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfile_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mr\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[32m     23\u001b[39m         lines = f.readlines() \u001b[38;5;66;03m# Reads all lines into a list\u001b[39;00m\n\u001b[32m     25\u001b[39m         TPR = []\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/IPython/core/interactiveshell.py:326\u001b[39m, in \u001b[36m_modified_open\u001b[39m\u001b[34m(file, *args, **kwargs)\u001b[39m\n\u001b[32m    319\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m file \u001b[38;5;129;01min\u001b[39;00m {\u001b[32m0\u001b[39m, \u001b[32m1\u001b[39m, \u001b[32m2\u001b[39m}:\n\u001b[32m    320\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m    321\u001b[39m         \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mIPython won\u001b[39m\u001b[33m'\u001b[39m\u001b[33mt let you open fd=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m by default \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    322\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    323\u001b[39m         \u001b[33m\"\u001b[39m\u001b[33myou can use builtins\u001b[39m\u001b[33m'\u001b[39m\u001b[33m open.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    324\u001b[39m     )\n\u001b[32m--> \u001b[39m\u001b[32m326\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mio_open\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[31mFileNotFoundError\u001b[39m: [Errno 2] No such file or directory: '/Users/a080528/Downloads/only_main.txt'"
     ]
    }
   ],
   "source": [
    "import ast\n",
    "\n",
    "def list2set(arr):\n",
    "    Set = set()\n",
    "    for item in arr:\n",
    "        Set.add(item[0]) if len(item) == 1 else Set.add(tuple(item))\n",
    "    \n",
    "    return Set\n",
    "\n",
    "def detection(tm, ti, arr, d):\n",
    "    trueset = list2set(tm+ti)\n",
    "    predset = list2set(arr)\n",
    "\n",
    "    TPR = len(trueset & predset)/ len(trueset)\n",
    "    FPR = len((predset - trueset))/ ((d*(d-1)/2)+d-len(trueset))\n",
    "    return TPR, FPR\n",
    "        \n",
    "def summary(file_path, true_main, true_inter, p):\n",
    "\n",
    "    \n",
    "\n",
    "    with open(file_path, 'r') as f:\n",
    "        lines = f.readlines() # Reads all lines into a list\n",
    "        \n",
    "        TPR = []\n",
    "        FPR = []\n",
    "        t = 0\n",
    "        for l in lines:\n",
    "            if 'Active' in l:\n",
    "                t+= 1\n",
    "                start_index = l.find(\"[[\")\n",
    "                array_string = l[start_index:].strip()\n",
    "                result_array = ast.literal_eval(array_string)\n",
    "                tvalue, fvalue = detection(true_main, true_inter, result_array, p)\n",
    "                TPR.append(tvalue)\n",
    "                FPR.append(fvalue)\n",
    "\n",
    "        print(f\"TPR Mean (DNN): {np.mean(TPR):.4f}, Std: {np.std(TPR):.4f}\")\n",
    "        print(f\"FPR Mean (SDAM): {np.mean(FPR):.4f}, Std: {np.std(FPR):.4f}\")\n",
    "\n",
    "    return \n",
    "\n",
    "num_feature = 150\n",
    "\n",
    "print('-------------------------------')\n",
    "file_path = '/Users/a080528/Downloads/only_main.txt'\n",
    "true_m = [[0], [1], [2], [3]]\n",
    "true_i = []\n",
    "print('Main Effect Only')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/weak_main.txt'\n",
    "true_m = [[0], [1], [2], [3]]\n",
    "print('Weak Main Effect Only')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/inter_no_overlap.txt'\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[3, 4]]\n",
    "print('Inter No Overlap')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/inter_mild_overlap.txt'\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[2, 3]]\n",
    "print('Inter Mild Overlap')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/inter_strong_overlap.txt'\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[1, 2]]\n",
    "print('Inter Strong Overlap')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/only_inter.txt'\n",
    "true_i = [[0, 1], [2, 3]]\n",
    "print('Only Inter Effects')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 201,
   "id": "fee98739",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-------------------------------\n",
      "Main Effect Only\n",
      "TPR Mean (DNN): 1.0000, Std: 0.0000\n",
      "FPR Mean (SDAM): 0.0000, Std: 0.0000\n",
      "-------------------------------\n",
      "Weak Main Effect Only\n",
      "TPR Mean (DNN): 1.0000, Std: 0.0000\n",
      "FPR Mean (SDAM): 0.0000, Std: 0.0000\n",
      "-------------------------------\n",
      "Inter No Overlap\n",
      "TPR Mean (DNN): 0.7500, Std: 0.0000\n",
      "FPR Mean (SDAM): 0.0001, Std: 0.0000\n",
      "-------------------------------\n",
      "Inter Mild Overlap\n",
      "TPR Mean (DNN): 0.7500, Std: 0.0000\n",
      "FPR Mean (SDAM): 0.0001, Std: 0.0000\n",
      "-------------------------------\n",
      "Inter Strong Overlap\n",
      "TPR Mean (DNN): 0.7550, Std: 0.0350\n",
      "FPR Mean (SDAM): 0.0001, Std: 0.0000\n",
      "-------------------------------\n",
      "Only Inter Effects\n",
      "TPR Mean (DNN): 0.6000, Std: 0.0000\n",
      "FPR Mean (SDAM): 0.0001, Std: 0.0000\n",
      "-------------------------------\n"
     ]
    }
   ],
   "source": [
    "num_feature = 150\n",
    "\n",
    "print('-------------------------------')\n",
    "file_path = '/Users/a080528/Downloads/results/only_main_n300.txt'\n",
    "true_m = [[0], [1], [2], [3]]\n",
    "true_i = []\n",
    "print('Main Effect Only')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/results/weak_main_n300.txt'\n",
    "true_m = [[0], [1], [2], [3]]\n",
    "print('Weak Main Effect Only')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/results/inter_no_overlap_n300.txt'\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[3, 4]]\n",
    "print('Inter No Overlap')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/results/inter_mild_overlap_n300.txt'\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[2, 3]]\n",
    "print('Inter Mild Overlap')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/results/inter_strong_overlap_n300.txt'\n",
    "true_m = [[0], [1], [2]]\n",
    "true_i = [[1, 2]]\n",
    "print('Inter Strong Overlap')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n",
    "\n",
    "file_type = '/Users/a080528/Downloads/results/only_inter_n300.txt'\n",
    "true_i = [[0, 1], [2, 3]]\n",
    "print('Only Inter Effects')\n",
    "summary(file_path, true_m, true_i, num_feature) \n",
    "print('-------------------------------')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31b1882d",
   "metadata": {},
   "source": [
    "### Group-LASSO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd477da3",
   "metadata": {},
   "outputs": [],
   "source": [
    "name = ['only_main_data', 'weak_main_data', 'inter_no_overlap_data', 'inter_mild_overlap_data', 'inter_strong_overlap_data', 'only_inter_data']\n",
    "name = ['only_main_data']\n",
    "\n",
    "for n in name:\n",
    "    _dict = torch.load('/Users/a080528/Downloads/'+ n + '.pt', weights_only= True)\n",
    "    X_train = np.array(_dict['X_train'])\n",
    "    y_train = np.array(_dict['y_train'])\n",
    "    X_test = np.array(_dict['X_test'])\n",
    "    y_test = np.array(_dict['y_test'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "438037d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/a080528/miniconda3/envs/torch/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import SplineTransformer, StandardScaler, PolynomialFeatures\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.linear_model import LassoCV, Lasso\n",
    "from group_lasso import GroupLasso\n",
    "import matplotlib.pyplot as plt\n",
    "import xgboost as xgb\n",
    "import shap\n",
    "import warnings\n",
    "import torch\n",
    "from itertools import combinations, combinations_with_replacement\n",
    "\n",
    "class AdditiveInteractionSelector:\n",
    "    \"\"\"\n",
    "    Fit additive models with candidate main effects and interactions\n",
    "    using spline basis expansions + group sparsity (group lasso).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_splines=10, spline_degree=3, include_intercept=False,\n",
    "                 interaction_splines=10, random_state=0):\n",
    "        self.n_splines = n_splines\n",
    "        self.spline_degree = spline_degree\n",
    "        self.include_intercept = include_intercept\n",
    "        self.interaction_splines = interaction_splines\n",
    "        self.random_state = random_state\n",
    "\n",
    "        # Internal storage\n",
    "        self.groups = []\n",
    "        self.group_names = []\n",
    "        self.scaler = None\n",
    "        self.model = None\n",
    "        self.group_norms_ = None\n",
    "        self.design_matrix_ = None\n",
    "\n",
    "    # -----------------------------\n",
    "    # Basis construction utilities\n",
    "    # -----------------------------\n",
    "    def _build_univariate_basis(self, x):\n",
    "        \"\"\"Build spline basis for one variable.\"\"\"\n",
    "        x = np.asarray(x).reshape(-1, 1)\n",
    "        sp = SplineTransformer(\n",
    "            degree=self.spline_degree,\n",
    "            n_knots=self.n_splines,\n",
    "            include_bias=self.include_intercept\n",
    "        )\n",
    "        return sp.fit_transform(x)\n",
    "\n",
    "    def _build_bivariate_basis(self, x1, x2):\n",
    "        \"\"\"Build tensor product spline basis for interaction.\"\"\"\n",
    "        B1 = self._build_univariate_basis(x1)\n",
    "        B2 = self._build_univariate_basis(x2)\n",
    "        # Tensor product\n",
    "        return np.einsum(\"ij,ik->ijk\", B1, B2).reshape(len(x1), -1)\n",
    "\n",
    "    def _build_design(self, X_df, interactions=None):\n",
    "        \"\"\"Construct design matrix with groups for univariates and interactions.\"\"\"\n",
    "        blocks, self.groups, self.group_names = [], [], []\n",
    "        col_idx = 0\n",
    "\n",
    "        # Main effects\n",
    "        for col in X_df.columns:\n",
    "            B = self._build_univariate_basis(X_df[col].values)\n",
    "            blocks.append(B)\n",
    "            self.groups.append(list(range(col_idx, col_idx + B.shape[1])))\n",
    "            self.group_names.append((col,))\n",
    "            col_idx += B.shape[1]\n",
    "\n",
    "        # Interactions\n",
    "        if interactions:\n",
    "            for a, b in interactions:\n",
    "                Bt = self._build_bivariate_basis(X_df[a].values, X_df[b].values)\n",
    "                blocks.append(Bt)\n",
    "                self.groups.append(list(range(col_idx, col_idx + Bt.shape[1])))\n",
    "                self.group_names.append((a, b))\n",
    "                col_idx += Bt.shape[1]\n",
    "\n",
    "        self.design_matrix_ = np.hstack(blocks)\n",
    "        return self.design_matrix_\n",
    "\n",
    "    # -----------------------------\n",
    "    # Fitting\n",
    "    # -----------------------------\n",
    "    def fit(self, X_df, y, interactions=None, cv=5, HAS_GROUP_LASSO=True):\n",
    "        \"\"\"\n",
    "        Fit model with group lasso (preferred) or fallback to plain Lasso.\n",
    "        \"\"\"\n",
    "        X = self._build_design(X_df, interactions)\n",
    "        self.scaler = StandardScaler()\n",
    "        Xs = self.scaler.fit_transform(X)\n",
    "\n",
    "        if HAS_GROUP_LASSO:\n",
    "            # Build group vector\n",
    "            col_to_group = np.zeros(X.shape[1], dtype=int)\n",
    "            for gid, idxs in enumerate(self.groups):\n",
    "                col_to_group[idxs] = gid\n",
    "\n",
    "            # Cross-validate group lasso penalty\n",
    "            lambdas = np.logspace(-3, 1, 10)\n",
    "            best_score, best_model = -np.inf, None\n",
    "            kf = KFold(n_splits=cv, shuffle=True, random_state=self.random_state)\n",
    "\n",
    "            for lam in lambdas:\n",
    "                scores = []\n",
    "                for tr, va in kf.split(Xs):\n",
    "                    gl = GroupLasso(\n",
    "                        groups=col_to_group,\n",
    "                        group_reg=lam, l1_reg=0.0,\n",
    "                        scale_reg=\"group_size\",\n",
    "                        supress_warning=True,\n",
    "                        n_iter=2000, tol=1e-3\n",
    "                    )\n",
    "                    gl.fit(Xs[tr], y[tr])\n",
    "                    scores.append(gl.score(Xs[va], y[va]))\n",
    "                if np.mean(scores) > best_score:\n",
    "                    best_score = np.mean(scores)\n",
    "                    best_model = GroupLasso(\n",
    "                        groups=col_to_group,\n",
    "                        group_reg=lam, l1_reg=0.0,\n",
    "                        scale_reg=\"group_size\",\n",
    "                        supress_warning=True,\n",
    "                        n_iter=2000, tol=1e-3\n",
    "                    )\n",
    "                    best_model.fit(Xs, y)\n",
    "\n",
    "            self.model = best_model\n",
    "            coefs = self.model.coef_.ravel()\n",
    "\n",
    "        else:\n",
    "            # Fallback to plain Lasso\n",
    "            lasso = LassoCV(cv=cv).fit(Xs, y)\n",
    "            self.model = lasso\n",
    "            coefs = lasso.coef_\n",
    "\n",
    "        # Compute group norms\n",
    "        self.group_norms_ = [\n",
    "            np.linalg.norm(coefs[idxs], ord=2) for idxs in self.groups\n",
    "        ]\n",
    "        return self\n",
    "\n",
    "    # -----------------------------\n",
    "    # Reporting\n",
    "    # -----------------------------\n",
    "    def get_group_importance(self):\n",
    "        \"\"\"Return DataFrame of group names and their norms.\"\"\"\n",
    "        return pd.DataFrame({\n",
    "            \"group\": self.group_names,\n",
    "            \"norm\": self.group_norms_\n",
    "        }).sort_values(\"norm\", ascending=False).reset_index(drop=True)\n",
    "\n",
    "    def get_important_groups(self, threshold=0.1):\n",
    "        \"\"\"Return groups with norms above threshold.\"\"\"\n",
    "        selected = []\n",
    "        for tup, val in zip(self.group_names, self.group_norms_):\n",
    "            if val > threshold:\n",
    "                indices = [int(s[1:]) - 1 for s in tup]  # convert \"x1\" → 0\n",
    "                selected.append(indices)\n",
    "        return selected\n",
    "    \n",
    "    def summary(self):\n",
    "        \"\"\"Print ranked group importance.\"\"\"\n",
    "        df = self.get_group_importance()\n",
    "        print(\"Group importance (higher = more important):\")\n",
    "        print(df)\n",
    "\n",
    "\n",
    "def extract_active_features(X: torch.tensor, active_idx: list[int]) -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    Extract active features from X based on active indices.\n",
    "    \n",
    "    Parameters\n",
    "    ----------\n",
    "    X : np.ndarray\n",
    "        Data matrix of shape (n_samples, n_features)\n",
    "    active_idx : list[int]\n",
    "        Indices of active features\n",
    "    \n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        DataFrame with columns named x1, x2, ..., for active features\n",
    "    \"\"\"\n",
    "    data = {f\"x{idx+1}\": X[:, idx] for idx in active_idx}\n",
    "    return pd.DataFrame(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d496105f",
   "metadata": {},
   "outputs": [],
   "source": [
    "threshold = 0.1\n",
    "main = [i for i in range(X_train[0].shape[1])]\n",
    "opt_df = extract_active_features(X_train[0], main)\n",
    "interactions = list(combinations(list(opt_df.keys()), 2))\n",
    "\n",
    "selector = AdditiveInteractionSelector(n_splines = 5, interaction_splines = 5)\n",
    "selector.fit(opt_df, y_train[0], interactions = None)\n",
    "\n",
    "GroupLasso_config = selector.get_important_groups(threshold=threshold)\n",
    "main_config = selector.get_important_groups(threshold=threshold)\n",
    "main_config = [GroupLasso_config[i][0]for i in range(len(GroupLasso_config))]\n",
    "opt_df = extract_active_features(X_train[0], main_config)\n",
    "interactions = list(combinations(list(opt_df.keys()), 2))\n",
    "\n",
    "selector.fit(opt_df, y_train[0], interactions = interactions)\n",
    "Final_config = selector.get_important_groups(threshold=threshold)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "id": "268336ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0], [1], [2], [3], [0, 2], [1, 2]]"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Final_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f2bc45d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60e57e72",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9b2ed170",
   "metadata": {},
   "source": [
    "### LASSO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6f5453f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50, 300, 150)"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "32bedf80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1000,)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test[0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb577a77",
   "metadata": {},
   "source": [
    "### LASSONET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "656bbe0d-ab9f-4e9f-8025-b6ae2386784b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 150, 150)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "name = ['only_main_data', 'weak_main_data', 'inter_no_overlap_data', 'inter_mild_overlap_data', 'inter_strong_overlap_data', 'only_inter_data']\n",
    "name = ['inter_no_overlap_data', 'inter_mild_overlap_data', 'inter_strong_overlap_data', 'only_inter_data']\n",
    "\n",
    "for n in name:\n",
    "    print(n)\n",
    "    _dict = torch.load('/Users/a080528/Downloads/data/'+ n + '.pt', weights_only= True)\n",
    "    X_train = np.array(_dict['X_train'])\n",
    "    y_train = np.array(_dict['y_train'])\n",
    "    X_test = np.array(_dict['X_test'])\n",
    "    y_test = np.array(_dict['y_test'])\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c6fab77-dbc4-4e35-bf89-a116e9759063",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b5286411",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-1\n",
      "19\n",
      "39\n",
      "TPR Mean (DNN): 0.6100, Std: 0.1241\n",
      "FPR Mean (SDAM): 0.0129, Std: 0.0091\n",
      "-1\n",
      "19\n",
      "39\n",
      "TPR Mean (DNN): 0.4550, Std: 0.1083\n",
      "FPR Mean (SDAM): 0.0168, Std: 0.0083\n",
      "-1\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mKeyboardInterrupt\u001b[39m                         Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[6]\u001b[39m\u001b[32m, line 49\u001b[39m\n\u001b[32m     38\u001b[39m X_basis = poly.fit_transform(X_train[i])\n\u001b[32m     40\u001b[39m model = LassoNetRegressorCV(\n\u001b[32m     41\u001b[39m     hidden_dims=(\u001b[32m10\u001b[39m, \u001b[32m10\u001b[39m),   \u001b[38;5;66;03m# neural net architecture\u001b[39;00m\n\u001b[32m     42\u001b[39m     M=\u001b[32m1\u001b[39m,                   \u001b[38;5;66;03m# hierarchy parameter (linear vs nonlinear strength)\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m     47\u001b[39m     torch_seed=\u001b[32m42\u001b[39m,\n\u001b[32m     48\u001b[39m ) \n\u001b[32m---> \u001b[39m\u001b[32m49\u001b[39m path = \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_basis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     50\u001b[39m pdim = X_train[i].shape[\u001b[32m1\u001b[39m]\n\u001b[32m     51\u001b[39m indice = [k \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(pdim)]\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.local/lib/python3.13/site-packages/lassonet/interfaces.py:908\u001b[39m, in \u001b[36mBaseLassoNetCV.path\u001b[39m\u001b[34m(self, X, y, return_state_dicts)\u001b[39m\n\u001b[32m    905\u001b[39m     \u001b[38;5;28mself\u001b[39m.lambda_start_ = \u001b[38;5;28mself\u001b[39m.lambdas_[\u001b[32m0\u001b[39m]\n\u001b[32m    907\u001b[39m \u001b[38;5;66;03m# train with the chosen lambda sequence\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m908\u001b[39m path = \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    909\u001b[39m \u001b[43m    \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    910\u001b[39m \u001b[43m    \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    911\u001b[39m \u001b[43m    \u001b[49m\u001b[43mlambda_seq\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mlambdas_\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mbest_lambda_idx\u001b[49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    912\u001b[39m \u001b[43m    \u001b[49m\u001b[43mreturn_state_dicts\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_state_dicts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    913\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    914\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m, LassoNetCoxRegressor) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m path[-\u001b[32m1\u001b[39m].selected.any():\n\u001b[32m    915\u001b[39m     \u001b[38;5;66;03m# condition to retrain and avoid having 0 feature which gives score 0\u001b[39;00m\n\u001b[32m    916\u001b[39m     \u001b[38;5;66;03m# TODO: handle backtrack in path 
even when return_state_dicts=False\u001b[39;00m\n\u001b[32m    917\u001b[39m     path = \u001b[38;5;28msuper\u001b[39m().path(\n\u001b[32m    918\u001b[39m         X,\n\u001b[32m    919\u001b[39m         y,\n\u001b[32m    920\u001b[39m         lambda_seq=[h.lambda_ \u001b[38;5;28;01mfor\u001b[39;00m h \u001b[38;5;129;01min\u001b[39;00m path[\u001b[32m1\u001b[39m:-\u001b[32m1\u001b[39m]],\n\u001b[32m    921\u001b[39m         return_state_dicts=return_state_dicts,\n\u001b[32m    922\u001b[39m     )\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.local/lib/python3.13/site-packages/lassonet/interfaces.py:471\u001b[39m, in \u001b[36mBaseLassoNet.path\u001b[39m\u001b[34m(self, X, y, X_val, y_val, lambda_seq, lambda_max, return_state_dicts, callback, disable_lambda_warning)\u001b[39m\n\u001b[32m    469\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.model.selected_count() == \u001b[32m0\u001b[39m:\n\u001b[32m    470\u001b[39m     \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m471\u001b[39m last = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_train\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    472\u001b[39m \u001b[43m    \u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    473\u001b[39m \u001b[43m    \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    474\u001b[39m \u001b[43m    \u001b[49m\u001b[43mX_val\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    475\u001b[39m \u001b[43m    \u001b[49m\u001b[43my_val\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    476\u001b[39m \u001b[43m    \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    477\u001b[39m \u001b[43m    \u001b[49m\u001b[43mlambda_\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcurrent_lambda\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    478\u001b[39m \u001b[43m    \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mn_iters_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    479\u001b[39m \u001b[43m    \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m=\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    480\u001b[39m \u001b[43m    
\u001b[49m\u001b[43mpatience\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpatience_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    481\u001b[39m \u001b[43m    \u001b[49m\u001b[43mreturn_state_dict\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_state_dicts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    482\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    483\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m is_dense \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.model.selected_count() < X_train.shape[\u001b[32m1\u001b[39m]:\n\u001b[32m    484\u001b[39m     is_dense = \u001b[38;5;28;01mFalse\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.local/lib/python3.13/site-packages/lassonet/interfaces.py:317\u001b[39m, in \u001b[36mBaseLassoNet._train\u001b[39m\u001b[34m(self, X_train, y_train, X_val, y_val, batch_size, epochs, lambda_, optimizer, return_state_dict, patience)\u001b[39m\n\u001b[32m    314\u001b[39m         loss += ans.item() * batch_size / n_train\n\u001b[32m    315\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m ans\n\u001b[32m--> \u001b[39m\u001b[32m317\u001b[39m     \u001b[43moptimizer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    318\u001b[39m     model.prox(\n\u001b[32m    319\u001b[39m         lambda_=lambda_ * optimizer.param_groups[\u001b[32m0\u001b[39m][\u001b[33m\"\u001b[39m\u001b[33mlr\u001b[39m\u001b[33m\"\u001b[39m],\n\u001b[32m    320\u001b[39m         M=\u001b[38;5;28mself\u001b[39m.M,\n\u001b[32m    321\u001b[39m     )\n\u001b[32m    323\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m epoch == \u001b[32m0\u001b[39m:\n\u001b[32m    324\u001b[39m     \u001b[38;5;66;03m# fallback to running loss of first epoch\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/torch/optim/optimizer.py:485\u001b[39m, in \u001b[36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[39m\u001b[34m(*args, **kwargs)\u001b[39m\n\u001b[32m    480\u001b[39m         \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m    481\u001b[39m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[32m    482\u001b[39m                 \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    483\u001b[39m             )\n\u001b[32m--> \u001b[39m\u001b[32m485\u001b[39m out = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    486\u001b[39m \u001b[38;5;28mself\u001b[39m._optimizer_step_code()\n\u001b[32m    488\u001b[39m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/torch/optim/optimizer.py:79\u001b[39m, in \u001b[36m_use_grad_for_differentiable.<locals>._use_grad\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m     77\u001b[39m     torch.set_grad_enabled(\u001b[38;5;28mself\u001b[39m.defaults[\u001b[33m\"\u001b[39m\u001b[33mdifferentiable\u001b[39m\u001b[33m\"\u001b[39m])\n\u001b[32m     78\u001b[39m     torch._dynamo.graph_break()\n\u001b[32m---> \u001b[39m\u001b[32m79\u001b[39m     ret = \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     80\u001b[39m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[32m     81\u001b[39m     torch._dynamo.graph_break()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/torch/optim/sgd.py:114\u001b[39m, in \u001b[36mSGD.step\u001b[39m\u001b[34m(self, closure)\u001b[39m\n\u001b[32m    112\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m    113\u001b[39m     \u001b[38;5;28;01mwith\u001b[39;00m torch.enable_grad():\n\u001b[32m--> \u001b[39m\u001b[32m114\u001b[39m         loss = \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    116\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.param_groups:\n\u001b[32m    117\u001b[39m     params: \u001b[38;5;28mlist\u001b[39m[Tensor] = []\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.local/lib/python3.13/site-packages/lassonet/interfaces.py:294\u001b[39m, in \u001b[36mBaseLassoNet._train.<locals>.closure\u001b[39m\u001b[34m()\u001b[39m\n\u001b[32m    292\u001b[39m \u001b[38;5;28;01mnonlocal\u001b[39;00m loss\n\u001b[32m    293\u001b[39m optimizer.zero_grad()\n\u001b[32m--> \u001b[39m\u001b[32m294\u001b[39m crit = \u001b[38;5;28mself\u001b[39m.criterion(\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m[\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m, y_train[batch])\n\u001b[32m    295\u001b[39m ans = (\n\u001b[32m    296\u001b[39m     crit\n\u001b[32m    297\u001b[39m     + \u001b[38;5;28mself\u001b[39m.gamma * model.l2_regularization()\n\u001b[32m    298\u001b[39m     + \u001b[38;5;28mself\u001b[39m.gamma_skip * model.l2_regularization_skip()\n\u001b[32m    299\u001b[39m )\n\u001b[32m    300\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m torch.isfinite(ans):\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/torch/nn/modules/module.py:1751\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1749\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m   1750\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1751\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/torch/nn/modules/module.py:1762\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1757\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m   1758\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m   1759\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m   1760\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m   1761\u001b[39m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1762\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1764\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m   1765\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.local/lib/python3.13/site-packages/lassonet/model.py:44\u001b[39m, in \u001b[36mLassoNet.forward\u001b[39m\u001b[34m(self, inp)\u001b[39m\n\u001b[32m     42\u001b[39m result = \u001b[38;5;28mself\u001b[39m.skip(inp)\n\u001b[32m     43\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m theta \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.layers:\n\u001b[32m---> \u001b[39m\u001b[32m44\u001b[39m     current_layer = \u001b[43mtheta\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcurrent_layer\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     45\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m theta \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m.layers[-\u001b[32m1\u001b[39m]:\n\u001b[32m     46\u001b[39m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.dropout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m/sysapps/ubuntu-applications/pytorch/2.7.0/venv/lib/python3.13/site-packages/torch/nn/modules/module.py:1747\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m   1744\u001b[39m             tracing_state.pop_scope()\n\u001b[32m   1745\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m result\n\u001b[32m-> \u001b[39m\u001b[32m1747\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_wrapped_call_impl\u001b[39m(\u001b[38;5;28mself\u001b[39m, *args, **kwargs):\n\u001b[32m   1748\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m   1749\u001b[39m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n",
      "\u001b[31mKeyboardInterrupt\u001b[39m: "
     ]
    }
   ],
   "source": [
    "from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
    "from itertools import combinations_with_replacement, compress\n",
    "from lassonet import LassoNetRegressor, LassoNetRegressorCV\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "import torch \n",
    "import numpy as np\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "def main_and_interactions(indices):\n",
    "    \"\"\"Return the feature groups for main effects followed by pairwise interactions.\n",
    "\n",
    "    Every index i contributes the singleton group [i]; afterwards each pair\n",
    "    produced by combinations_with_replacement(indices, 2) is appended as a\n",
    "    list, so squared terms such as [i, i] are included.\n",
    "    \"\"\"\n",
    "    singles = [[idx] for idx in indices]\n",
    "    pairs = [list(pair) for pair in combinations_with_replacement(indices, 2)]\n",
    "    return singles + pairs\n",
    "\n",
    "# Dataset file stems; each one is loaded from '../data/<stem>.pt'.\n",
    "# Both variants are kept for reference, but only the list bound to `name` is evaluated.\n",
    "name_base = ['only_main_data', 'weak_main_data', 'inter_no_overlap_data', 'inter_mild_overlap_data', 'inter_strong_overlap_data', 'only_inter_data']\n",
    "name_300 = ['only_main300_data', 'weak_main300_data', 'inter_no_overlap300_data', 'inter_mild_overlap300_data', 'inter_strong_overlap300_data', 'only_inter300_data']\n",
    "name = name_300\n",
    "\n",
    "# Per-dataset reference groups handed to detection(), in the same order as `name`\n",
    "# (presumably the true main-effect groups and true interaction groups -- confirm against detection()).\n",
    "tml = [[[0], [1], [2], [3]], [[0], [1], [2], [3]], [[0], [1], [2]], [[0], [1], [2]], \n",
    "      [[0], [1], [2]], []]\n",
    "til = [[], [], [[3, 4]], [[2,3]], [[1,2]], [[1,2], [3,4]]] \n",
    "# Number of replicates evaluated per dataset (X_train / y_train are indexed 0..r-1 on their first axis).\n",
    "r = 50\n",
    "\n",
    "# The degree-2 expansion carries no data-dependent settings, so one transformer is reused;\n",
    "# fit_transform re-fits it on every replicate.\n",
    "poly = PolynomialFeatures(degree=2, include_bias=False)\n",
    "\n",
    "for j in range(len(name)):\n",
    "    _dict = torch.load('../data/'+ name[j] + '.pt', weights_only= True)\n",
    "    X_train = np.array(_dict['X_train'])\n",
    "    y_train = np.array(_dict['y_train'])\n",
    "\n",
    "    TPR = []\n",
    "    FPR = []\n",
    "    for i in range(r):\n",
    "        # Progress marker: index of the replicate about to be fitted.\n",
    "        if i % 20 == 0:\n",
    "            print(i)\n",
    "\n",
    "        # Centre and scale with this replicate's own mean and std. The result goes into a\n",
    "        # new array so X_train is left untouched and no replicate's scaling depends on\n",
    "        # the replicates processed before it.\n",
    "        X_rep = X_train[i]\n",
    "        X_scaled = (X_rep - X_rep.mean()) / X_rep.std()\n",
    "        X_basis = poly.fit_transform(X_scaled)\n",
    "\n",
    "        # A fresh model per replicate so no fitted state carries over.\n",
    "        model = LassoNetRegressorCV(\n",
    "            hidden_dims=(10, 10),   # neural net architecture\n",
    "            M=1,                   # hierarchy parameter (linear vs nonlinear strength)\n",
    "            path_multiplier=2,      # geometric progression along regularization path\n",
    "            verbose=0,              # verbosity\n",
    "            patience=10,            # early stopping for training\n",
    "            batch_size=128,\n",
    "            torch_seed=42,\n",
    "        ) \n",
    "        path = model.path(X_basis, y_train[i])\n",
    "\n",
    "        # One group per expanded column: singletons first, then pairs (including squares).\n",
    "        # Assumes this order matches the PolynomialFeatures column order -- see sklearn docs.\n",
    "        pdim = X_scaled.shape[1]\n",
    "        indice = list(range(pdim))\n",
    "        group_ID = main_and_interactions(indice)\n",
    "\n",
    "        # Pick the path entry with the smallest validation loss and read off its selected groups.\n",
    "        val_loss_list = np.array([step.val_loss for step in path])\n",
    "        min_vloss_idx = int(np.argmin(val_loss_list))\n",
    "        pred_set = list(compress(group_ID, path[min_vloss_idx].selected))\n",
    "        tvalue, fvalue = detection(tml[j], til[j], pred_set, pdim)\n",
    "        TPR.append(tvalue)\n",
    "        FPR.append(fvalue)\n",
    "\n",
    "    # Summary over the r replicates of dataset name[j].\n",
    "    print(f\"TPR Mean (LassoNet): {np.mean(TPR):.4f}, Std: {np.std(TPR):.4f}\")\n",
    "    print(f\"FPR Mean (LassoNet): {np.mean(FPR):.4f}, Std: {np.std(FPR):.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9c212214",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "807a7475",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
