{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "863a53bf-8962-45c8-baae-ba17eb210e85",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING! The apex is not installed so fp16 is not available.\n"
     ]
    }
   ],
   "source": [
    "from nodegam.sklearn import NodeGAMRegressor\n",
    "from sklearn.model_selection import train_test_split\n",
    "import pandas as pd\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def normalize(x):\n",
    "    x_min = x.min(dim=0, keepdim=True).values  # 每欄最小值，維持維度\n",
    "    x_max = x.max(dim=0, keepdim=True).values  # 每欄最大值，維持維度\n",
    "\n",
    "    x_normalized = (x - x_min) / (x_max - x_min)\n",
    "\n",
    "    return x_normalized\n",
    "\n",
    "# Set two CUDA-related environment variables for this kernel process.\n",
    "# NOTE(review): these are assigned after `import torch` earlier in this cell --\n",
    "# confirm they still take effect in this setup.\n",
    "import os\n",
    "os.environ['CUDA_LAUNCH_BLOCKING']=\"1\"\n",
    "os.environ['TORCH_USE_CUDA_DSA'] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6fb41325-5453-494d-b8bd-b0a1ee8824a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Down sampling on training data ... \n",
      "X_train size: torch.Size([100, 30]), y_train size: torch.Size([100])\n",
      "X_test size: torch.Size([400, 30]), y_test size: torch.Size([400])\n",
      "X_val.shape: torch.Size([50, 30]) | y_val.shape: torch.Size([50])\n"
     ]
    }
   ],
   "source": [
    "# Chip dataset (SDAM/data/Chip_dict.pt) padded with uninformative noise columns.\n",
    "pathname = './data/Chip_dict.pt'\n",
    "subsize = 150   # number of leading training rows kept before the train/val split\n",
    "n_noise = 21    # uniform(-1, 1) noise columns appended to every split\n",
    "# Seeded generator so the noise columns are identical on every Restart & Run All.\n",
    "noise_rng = np.random.default_rng(42)\n",
    "\n",
    "Chip_data = torch.load(pathname, weights_only = True)\n",
    "X_train, y_train, X_test, y_test = Chip_data['Xtrain'], Chip_data['ytrain'], Chip_data['Xtest'], Chip_data['ytest']\n",
    "X_train, y_train, X_test, y_test = X_train.numpy(), y_train.numpy().reshape(-1), X_test.numpy(), y_test.numpy().reshape(-1)\n",
    "\n",
    "print(\"Down sampling on training data ... \")\n",
    "X_train, y_train = X_train[:subsize, :], y_train[:subsize]\n",
    "\n",
    "# Append the noise features to the training rows, then convert to float32 tensors.\n",
    "sparse_array_train = noise_rng.uniform(low=-1, high=1, size=(X_train.shape[0], n_noise))\n",
    "X_train = np.concatenate([X_train, sparse_array_train], axis = 1)\n",
    "X_train = torch.tensor(X_train).to(torch.float32)\n",
    "y_train = torch.tensor(y_train).to(torch.float32)\n",
    "\n",
    "# Hold out one third of the down-sampled training rows for validation.\n",
    "X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.333, random_state=42)\n",
    "\n",
    "# Same noise padding for the test rows.\n",
    "sparse_array_test = noise_rng.uniform(low=-1, high=1, size=(X_test.shape[0], n_noise))\n",
    "X_test = np.concatenate([X_test, sparse_array_test], axis = 1)\n",
    "X_test = torch.tensor(X_test).to(torch.float32)\n",
    "y_test = torch.tensor(y_test).to(torch.float32)\n",
    "\n",
    "## Normalization\n",
    "# NOTE(review): each split is scaled with its own min/max; consider reusing the\n",
    "# training statistics for X_val / X_test so all splits share one scale.\n",
    "X_train, X_val, X_test = normalize(X_train), normalize(X_val), normalize(X_test)\n",
    "\n",
    "\n",
    "print(f\"X_train size: {X_train.shape}, y_train size: {y_train.shape}\")\n",
    "print(f\"X_test size: {X_test.shape}, y_test size: {y_test.shape}\")\n",
    "print(f\"X_val.shape: {X_val.shape} | y_val.shape: {y_val.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa3043c5-2a4d-4bdc-8915-a3f6e07596c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fdf04f66-3ec3-495c-afc4-24050437d638",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_trn.shape: torch.Size([353, 10]) | y_trn.shape: torch.Size([353])\n",
      "X_val.shape: torch.Size([44, 10]) | y_val.shape: torch.Size([44])\n",
      "X_val.shape: torch.Size([45, 10]) | y_val.shape: torch.Size([45])\n"
     ]
    }
   ],
   "source": [
    "from sklearn.datasets import load_diabetes\n",
    "\n",
    "diabetes = load_diabetes()\n",
    "\n",
    "# Features, padded with 40 uninformative uniform(-1, 1) noise columns.\n",
    "# A seeded generator keeps the noise identical on every Restart & Run All.\n",
    "noise_rng = np.random.default_rng(42)\n",
    "X = diabetes.data\n",
    "sparse_array = noise_rng.uniform(low=-1, high=1, size=(X.shape[0], 40))\n",
    "X = np.concatenate([X, sparse_array], axis = 1)\n",
    "X = torch.tensor(X).to(torch.float32)\n",
    "\n",
    "# Regression target.\n",
    "y = diabetes.target\n",
    "y = torch.tensor(y).to(torch.float32)\n",
    "\n",
    "# 80% train, then the remaining 20% is halved into validation and test.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
    "X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=0.5, random_state=42)\n",
    "\n",
    "# Normalization\n",
    "# NOTE(review): each split is scaled with its own min/max; consider reusing the\n",
    "# training statistics for X_val / X_test so all splits share one scale.\n",
    "X_train, X_val, X_test = normalize(X_train), normalize(X_val), normalize(X_test)\n",
    "\n",
    "print(f\"X_trn.shape: {X_train.shape} | y_trn.shape: {y_train.shape}\")\n",
    "print(f\"X_val.shape: {X_val.shape} | y_val.shape: {y_val.shape}\")\n",
    "print(f\"X_tst.shape: {X_test.shape} | y_tst.shape: {y_test.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de351ce9-335a-4337-a122-309225463c54",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e6637547-b221-46ff-8f90-39c5ef7b7866",
   "metadata": {},
   "source": [
    "## Modeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2e61e832-a718-4f73-bec8-95b457500748",
   "metadata": {},
   "outputs": [],
   "source": [
    "case = 'wine'\n",
    "\n",
    "# Read the train / validation / test CSVs for this case.\n",
    "X_train = pd.read_csv('./Additional_exp/'+case+'_Train.csv')\n",
    "X_val = pd.read_csv('./Additional_exp/'+case+'_Valid.csv')\n",
    "X_test = pd.read_csv('./Additional_exp/'+case+'_Test.csv')\n",
    "\n",
    "# `pop` returns the target column 'y' and removes it from the frame,\n",
    "# leaving only the feature columns behind.\n",
    "y_train = X_train.pop('y')\n",
    "y_val = X_val.pop('y')\n",
    "y_test = X_test.pop('y')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f598b4-05e9-4df5-b636-5e3e68090da1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed3508f0-ab9b-414e-9df2-d4843f2dd2ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit NodeGAMRegressor (ga2m = 0, CPU) on the training split and print the test MSE.\n",
    "model = NodeGAMRegressor(\n",
    "    in_features = X_train.shape[1],\n",
    "    ga2m = 0,\n",
    "    verbose=0,\n",
    "    device = 'cpu',\n",
    ")\n",
    "model.fit(pd.DataFrame(X_train), np.array(y_train))\n",
    "\n",
    "pred_test = model.predict(pd.DataFrame(X_test))\n",
    "squared_error = np.square(pred_test - np.array(y_test))\n",
    "print(np.mean(squared_error))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64a6663d-5513-49c3-a87d-8232e186f88b",
   "metadata": {},
   "source": [
    "## NAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "c0a90a6a-c0b4-4ba5-aa90-746a48967fa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dnamite.models import DNAMiteRegressor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "id": "6e8dc872-07fc-4b36-bbbb-666cd83a1ec9",
   "metadata": {},
   "outputs": [],
   "source": [
    "case = 'fico'\n",
    "\n",
    "# Read the train / validation / test CSVs for this case.\n",
    "X_train = pd.read_csv('./Additional_exp/'+case+'_Train.csv')\n",
    "X_val = pd.read_csv('./Additional_exp/'+case+'_Valid.csv')\n",
    "X_test = pd.read_csv('./Additional_exp/'+case+'_Test.csv')\n",
    "\n",
    "# `pop` returns the target column 'y' and removes it from the frame,\n",
    "# leaving only the feature columns behind.\n",
    "y_train = X_train.pop('y')\n",
    "y_val = X_val.pop('y')\n",
    "y_test = X_test.pop('y')\n",
    "\n",
    "# Relabel the test columns with their position as strings: '0', '1', ...\n",
    "# NOTE(review): only X_test is relabelled here -- confirm this is intended.\n",
    "rec = pd.Index(list(map(str, range(X_test.shape[1]))), dtype = 'object')\n",
    "X_test = X_test.set_axis(rec, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "id": "ec842a1e-6412-44c8-b800-df294ff7dd09",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-8 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: #000;\n",
       "  --sklearn-color-text-muted: #666;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-8 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-8 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-8 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-8 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-8 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: flex;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "  align-items: start;\n",
       "  justify-content: space-between;\n",
       "  gap: 0.5em;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 label.sk-toggleable__label .caption {\n",
       "  font-size: 0.6rem;\n",
       "  font-weight: lighter;\n",
       "  color: var(--sklearn-color-text-muted);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-8 div.sk-toggleable__content {\n",
       "  display: none;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  display: block;\n",
       "  width: 100%;\n",
       "  overflow: visible;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-8 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-8 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-8 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-8 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-8 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-8 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-8 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-8 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-8 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 0.5em;\n",
       "  text-align: center;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-8 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-8 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-8 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".estimator-table summary {\n",
       "    padding: .5rem;\n",
       "    font-family: monospace;\n",
       "    cursor: pointer;\n",
       "}\n",
       "\n",
       ".estimator-table details[open] {\n",
       "    padding-left: 0.1rem;\n",
       "    padding-right: 0.1rem;\n",
       "    padding-bottom: 0.3rem;\n",
       "}\n",
       "\n",
       ".estimator-table .parameters-table {\n",
       "    margin-left: auto !important;\n",
       "    margin-right: auto !important;\n",
       "}\n",
       "\n",
       ".estimator-table .parameters-table tr:nth-child(odd) {\n",
       "    background-color: #fff;\n",
       "}\n",
       "\n",
       ".estimator-table .parameters-table tr:nth-child(even) {\n",
       "    background-color: #f6f6f6;\n",
       "}\n",
       "\n",
       ".estimator-table .parameters-table tr:hover {\n",
       "    background-color: #e0e0e0;\n",
       "}\n",
       "\n",
       ".estimator-table table td {\n",
       "    border: 1px solid rgba(106, 105, 104, 0.232);\n",
       "}\n",
       "\n",
       ".user-set td {\n",
       "    color:rgb(255, 94, 0);\n",
       "    text-align: left;\n",
       "}\n",
       "\n",
       ".user-set td.value pre {\n",
       "    color:rgb(255, 94, 0) !important;\n",
       "    background-color: transparent !important;\n",
       "}\n",
       "\n",
       ".default td {\n",
       "    color: black;\n",
       "    text-align: left;\n",
       "}\n",
       "\n",
       ".user-set td i,\n",
       ".default td i {\n",
       "    color: black;\n",
       "}\n",
       "\n",
       ".copy-paste-icon {\n",
       "    background-image: url(data:image/svg+xml;base64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCA0NDggNTEyIj48IS0tIUZvbnQgQXdlc29tZSBGcmVlIDYuNy4yIGJ5IEBmb250YXdlc29tZSAtIGh0dHBzOi8vZm9udGF3ZXNvbWUuY29tIExpY2Vuc2UgLSBodHRwczovL2ZvbnRhd2Vzb21lLmNvbS9saWNlbnNlL2ZyZWUgQ29weXJpZ2h0IDIwMjUgRm9udGljb25zLCBJbmMuLS0+PHBhdGggZD0iTTIwOCAwTDMzMi4xIDBjMTIuNyAwIDI0LjkgNS4xIDMzLjkgMTQuMWw2Ny45IDY3LjljOSA5IDE0LjEgMjEuMiAxNC4xIDMzLjlMNDQ4IDMzNmMwIDI2LjUtMjEuNSA0OC00OCA0OGwtMTkyIDBjLTI2LjUgMC00OC0yMS41LTQ4LTQ4bDAtMjg4YzAtMjYuNSAyMS41LTQ4IDQ4LTQ4ek00OCAxMjhsODAgMCAwIDY0LTY0IDAgMCAyNTYgMTkyIDAgMC0zMiA2NCAwIDAgNDhjMCAyNi41LTIxLjUgNDgtNDggNDhMNDggNTEyYy0yNi41IDAtNDgtMjEuNS00OC00OEwwIDE3NmMwLTI2LjUgMjEuNS00OCA0OC00OHoiLz48L3N2Zz4=);\n",
       "    background-repeat: no-repeat;\n",
       "    background-size: 14px 14px;\n",
       "    background-position: 0;\n",
       "    display: inline-block;\n",
       "    width: 14px;\n",
       "    height: 14px;\n",
       "    cursor: pointer;\n",
       "}\n",
       "</style><body><div id=\"sk-container-id-8\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>DNAMiteRegressor(random_state=20)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-8\" type=\"checkbox\" checked><label for=\"sk-estimator-id-8\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow\"><div><div>DNAMiteRegressor</div></div><div><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></div></label><div class=\"sk-toggleable__content fitted\" data-param-prefix=\"\">\n",
       "        <div class=\"estimator-table\">\n",
       "            <details>\n",
       "                <summary>Parameters</summary>\n",
       "                <table class=\"parameters-table\">\n",
       "                  <tbody>\n",
       "                    \n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('n_embed',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">n_embed&nbsp;</td>\n",
       "            <td class=\"value\">32</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('n_hidden',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">n_hidden&nbsp;</td>\n",
       "            <td class=\"value\">32</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('n_layers',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">n_layers&nbsp;</td>\n",
       "            <td class=\"value\">2</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('max_bins',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">max_bins&nbsp;</td>\n",
       "            <td class=\"value\">32</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('min_samples_per_bin',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">min_samples_per_bin&nbsp;</td>\n",
       "            <td class=\"value\">None</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('validation_size',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">validation_size&nbsp;</td>\n",
       "            <td class=\"value\">0.2</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('n_val_splits',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">n_val_splits&nbsp;</td>\n",
       "            <td class=\"value\">5</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('learning_rate',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">learning_rate&nbsp;</td>\n",
       "            <td class=\"value\">0.0005</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('max_epochs',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">max_epochs&nbsp;</td>\n",
       "            <td class=\"value\">100</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('batch_size',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">batch_size&nbsp;</td>\n",
       "            <td class=\"value\">128</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('device',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">device&nbsp;</td>\n",
       "            <td class=\"value\">&#x27;cpu&#x27;</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('kernel_size',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">kernel_size&nbsp;</td>\n",
       "            <td class=\"value\">5</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('kernel_weight',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">kernel_weight&nbsp;</td>\n",
       "            <td class=\"value\">3</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('pair_kernel_size',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">pair_kernel_size&nbsp;</td>\n",
       "            <td class=\"value\">3</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('pair_kernel_weight',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">pair_kernel_weight&nbsp;</td>\n",
       "            <td class=\"value\">3</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('monotone_constraints',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">monotone_constraints&nbsp;</td>\n",
       "            <td class=\"value\">None</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('num_pairs',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">num_pairs&nbsp;</td>\n",
       "            <td class=\"value\">0</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"default\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('verbosity',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">verbosity&nbsp;</td>\n",
       "            <td class=\"value\">0</td>\n",
       "        </tr>\n",
       "    \n",
       "\n",
       "        <tr class=\"user-set\">\n",
       "            <td><i class=\"copy-paste-icon\"\n",
       "                 onclick=\"copyToClipboard('random_state',\n",
       "                          this.parentElement.nextElementSibling)\"\n",
       "            ></i></td>\n",
       "            <td class=\"param\">random_state&nbsp;</td>\n",
       "            <td class=\"value\">20</td>\n",
       "        </tr>\n",
       "    \n",
       "                  </tbody>\n",
       "                </table>\n",
       "            </details>\n",
       "        </div>\n",
       "    </div></div></div></div></div><script>function copyToClipboard(text, element) {\n",
       "    // Get the parameter prefix from the closest toggleable content\n",
       "    const toggleableContent = element.closest('.sk-toggleable__content');\n",
       "    const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';\n",
       "    const fullParamName = paramPrefix ? `${paramPrefix}${text}` : text;\n",
       "\n",
       "    const originalStyle = element.style;\n",
       "    const computedStyle = window.getComputedStyle(element);\n",
       "    const originalWidth = computedStyle.width;\n",
       "    const originalHTML = element.innerHTML.replace('Copied!', '');\n",
       "\n",
       "    navigator.clipboard.writeText(fullParamName)\n",
       "        .then(() => {\n",
       "            element.style.width = originalWidth;\n",
       "            element.style.color = 'green';\n",
       "            element.innerHTML = \"Copied!\";\n",
       "\n",
       "            setTimeout(() => {\n",
       "                element.innerHTML = originalHTML;\n",
       "                element.style = originalStyle;\n",
       "            }, 2000);\n",
       "        })\n",
       "        .catch(err => {\n",
       "            console.error('Failed to copy:', err);\n",
       "            element.style.color = 'red';\n",
       "            element.innerHTML = \"Failed!\";\n",
       "            setTimeout(() => {\n",
       "                element.innerHTML = originalHTML;\n",
       "                element.style = originalStyle;\n",
       "            }, 2000);\n",
       "        });\n",
       "    return false;\n",
       "}\n",
       "\n",
       "document.querySelectorAll('.fa-regular.fa-copy').forEach(function(element) {\n",
       "    const toggleableContent = element.closest('.sk-toggleable__content');\n",
       "    const paramPrefix = toggleableContent ? toggleableContent.dataset.paramPrefix : '';\n",
       "    const paramName = element.parentElement.nextElementSibling.textContent.trim();\n",
       "    const fullParamName = paramPrefix ? `${paramPrefix}${paramName}` : paramName;\n",
       "\n",
       "    element.setAttribute('title', fullParamName);\n",
       "});\n",
       "</script></body>"
      ],
      "text/plain": [
       "DNAMiteRegressor(random_state=20)"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed the model so the fit is reproducible; 20 is the value recorded in this\n",
    "# cell's saved output repr, DNAMiteRegressor(random_state=20).\n",
    "model = DNAMiteRegressor(random_state=20)\n",
    "model.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "id": "d7f7eb43-cdfd-410c-867b-61e6fca9d0e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DNAMite RMSE: 0.4855042580182042\n"
     ]
    }
   ],
   "source": [
    "# Predict on the test split and report the root-mean-squared error.\n",
    "pred_test = model.predict(X_test)\n",
    "residuals = pred_test - y_test\n",
    "test_rmse = np.sqrt(np.mean(residuals**2))\n",
    "print(f\"DNAMite RMSE: {test_rmse}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "id": "1525d00c-f663-4080-abd2-d0c542868438",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0       1.0\n",
       "1       1.0\n",
       "2       1.0\n",
       "3       1.0\n",
       "4       0.0\n",
       "       ... \n",
       "1970    0.0\n",
       "1971    0.0\n",
       "1972    1.0\n",
       "1973    0.0\n",
       "1974    0.0\n",
       "Name: y, Length: 1975, dtype: float64"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the test targets to inspect their values and dtype.\n",
    "y_test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa27388c-7fa1-42cc-b2a1-a32faf7f925d",
   "metadata": {},
   "source": [
    "## GAMINET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a8f46a3f-64a0-4170-bb31-506789545208",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "PACKAGE_PARENT = '..'\n",
    "sys.path.append(PACKAGE_PARENT)\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from gaminet import GAMINetRegressor, GAMINetClassifier\n",
    "from gaminet.utils import local_visualize\n",
    "from gaminet.utils import global_visualize_density\n",
    "from gaminet.utils import feature_importance_visualize\n",
    "from gaminet.utils import plot_trajectory\n",
    "from gaminet.utils import plot_regularization\n",
    "import time\n",
    "import sklearn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff73516-3105-4654-bd7e-20b9755d893c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "eb73ec16-ad96-4ec0-bd45-a50b5f4c51b6",
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'GAMINetRegressor' object has no attribute '_validate_data'",
     "output_type": "error",
     "traceback": [
      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
      "\u001b[31mAttributeError\u001b[39m                            Traceback (most recent call last)",
      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 17\u001b[39m\n\u001b[32m      2\u001b[39m y_train = np.asarray(y_train).reshape(-\u001b[32m1\u001b[39m)\n\u001b[32m      4\u001b[39m model = GAMINetRegressor(interact_num=\u001b[32m10\u001b[39m,\n\u001b[32m      5\u001b[39m                  subnet_size_main_effect=(\u001b[32m20\u001b[39m, ) ,\n\u001b[32m      6\u001b[39m                  subnet_size_interaction=(\u001b[32m20\u001b[39m, \u001b[32m20\u001b[39m),\n\u001b[32m   (...)\u001b[39m\u001b[32m     14\u001b[39m                  verbose=\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[32m     15\u001b[39m                  random_state=\u001b[32m0\u001b[39m)\n\u001b[32m---> \u001b[39m\u001b[32m17\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     18\u001b[39m pred_test = model.predict(np.asarray(X_test))\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.local/lib/python3.13/site-packages/gaminet/api.py:231\u001b[39m, in \u001b[36mGAMINetRegressor.fit\u001b[39m\u001b[34m(self, x, y, sample_weight)\u001b[39m\n\u001b[32m    213\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, y, sample_weight=\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[32m    214\u001b[39m \u001b[38;5;250m    \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m    215\u001b[39m \u001b[33;03m    Fit GAMINetRegressor model.\u001b[39;00m\n\u001b[32m    216\u001b[39m \n\u001b[32m   (...)\u001b[39m\u001b[32m    229\u001b[39m \u001b[33;03m        Fitted Estimator.\u001b[39;00m\n\u001b[32m    230\u001b[39m \u001b[33;03m    \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m231\u001b[39m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_init_fit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    232\u001b[39m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._fit()\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.local/lib/python3.13/site-packages/gaminet/base.py:561\u001b[39m, in \u001b[36mGAMINet._init_fit\u001b[39m\u001b[34m(self, x, y, sample_weight, stratified)\u001b[39m\n\u001b[32m    557\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.activation_func \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m ACTIVATIONS:\n\u001b[32m    558\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mThe activation \u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m'\u001b[39m\u001b[33m is not supported. Supported activations are \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m.\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m    559\u001b[39m                 % (\u001b[38;5;28mself\u001b[39m.activation_func, \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28msorted\u001b[39m(ACTIVATIONS))))\n\u001b[32m--> \u001b[39m\u001b[32m561\u001b[39m x, y, sample_weight = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_validate_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m    562\u001b[39m n_jobs = cpu_count() \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.n_jobs == -\u001b[32m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mmin\u001b[39m(\u001b[38;5;28mint\u001b[39m(\u001b[38;5;28mself\u001b[39m.n_jobs), cpu_count())\n\u001b[32m    563\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.reg_clarity > \u001b[32m0\u001b[39m:\n",
      "\u001b[36mFile \u001b[39m\u001b[32m~/.local/lib/python3.13/site-packages/gaminet/api.py:90\u001b[39m, in \u001b[36mGAMINetRegressor._validate_input\u001b[39m\u001b[34m(self, x, y, sample_weight)\u001b[39m\n\u001b[32m     64\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_validate_input\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, y, sample_weight):\n\u001b[32m     65\u001b[39m \u001b[38;5;250m    \u001b[39m\u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m     66\u001b[39m \u001b[33;03m    Internal function for validating the inputs of the fit function.\u001b[39;00m\n\u001b[32m     67\u001b[39m \n\u001b[32m   (...)\u001b[39m\u001b[32m     88\u001b[39m \u001b[33;03m        Sample weight.\u001b[39;00m\n\u001b[32m     89\u001b[39m \u001b[33;03m    \"\"\"\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m90\u001b[39m     x, y = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_validate_data\u001b[49m(x, y, y_numeric=\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[32m     91\u001b[39m     \u001b[38;5;28;01mif\u001b[39;00m y.ndim == \u001b[32m2\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m y.shape[\u001b[32m1\u001b[39m] == \u001b[32m1\u001b[39m:\n\u001b[32m     92\u001b[39m         y = column_or_1d(y, warn=\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "\u001b[31mAttributeError\u001b[39m: 'GAMINetRegressor' object has no attribute '_validate_data'"
     ]
    }
   ],
   "source": [
    "# Compatibility shim: gaminet calls the private BaseEstimator._validate_data,\n",
    "# which recent scikit-learn (1.6+) no longer provides, so fit() raises\n",
    "# AttributeError. Forward the call to the public validate_data helper instead.\n",
    "# The hasattr guard keeps this a no-op on scikit-learn versions that still\n",
    "# ship the method.\n",
    "# NOTE(review): this only covers _validate_data; other gaminet/scikit-learn\n",
    "# incompatibilities may remain -- pinning scikit-learn<1.6 is the alternative.\n",
    "if not hasattr(GAMINetRegressor, \"_validate_data\"):\n",
    "    from sklearn.utils.validation import validate_data\n",
    "\n",
    "    def _validate_data_compat(self, *args, **kwargs):\n",
    "        \"\"\"Forward the legacy private call to sklearn's public validate_data.\"\"\"\n",
    "        return validate_data(self, *args, **kwargs)\n",
    "\n",
    "    GAMINetRegressor._validate_data = _validate_data_compat\n",
    "\n",
    "# Convert the training data to NumPy arrays with a flat (1-D) target.\n",
    "X_train = np.asarray(X_train)\n",
    "y_train = np.asarray(y_train).reshape(-1)\n",
    "\n",
    "model = GAMINetRegressor(interact_num=10,\n",
    "                 subnet_size_main_effect=(20, ) ,\n",
    "                 subnet_size_interaction=(20, 20),\n",
    "                 max_epochs=(1000, 1000, 1000),\n",
    "                 learning_rates=(0.001, 0.001, 0.0001),\n",
    "                 early_stop_thres=(\"auto\", \"auto\", \"auto\"),\n",
    "                 batch_size=1000,\n",
    "                 reg_clarity=1,\n",
    "                 loss_threshold=0.01,\n",
    "                 warm_start=True,\n",
    "                 verbose=False,\n",
    "                 random_state=0)\n",
    "\n",
    "model.fit(X_train, y_train)\n",
    "pred_test = model.predict(np.asarray(X_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4643ca05-262e-48b6-a61b-0fff43e2e030",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "acfb56c8-ab84-4401-8096-80c5cfea3b55",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Image V1-cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "425a9a96-8d94-47be-a125-d5f875d2a60b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "\n",
    "# Data directory, relative to the notebook's working directory -- the same\n",
    "# convention as the './data/Chip_dict.pt' load at the top of the notebook.\n",
    "# Kept as a string with a trailing slash so `path + filename` still works.\n",
    "path = './data/'\n",
    "\n",
    "# Indexing an HDF5 dataset with [:] reads it fully into memory as a NumPy array.\n",
    "with h5py.File(path+'EstimatedResponses.mat', 'r') as f:\n",
    "    YTrainS1 = f['dataTrnS1'][:]\n",
    "    YTrainS2 = f['dataTrnS2'][:]\n",
    "    YValS1 = f['dataValS1'][:]\n",
    "    YValS2 = f['dataValS2'][:]\n",
    "    voxS1 = f['voxIdxS1'][:]\n",
    "    voxS2 = f['voxIdxS2'][:]\n",
    "    roiS1 = f['roiS1'][:]\n",
    "    roiS2 = f['roiS2'][:]\n",
    "\n",
    "# Keep only the response columns whose entry in roiS1 equals 1\n",
    "# (presumably the V1 voxels -- TODO confirm against the dataset description).\n",
    "roi1_cols = np.where(roiS1 == 1)[1]\n",
    "y_trn = torch.tensor(YTrainS1[:, roi1_cols])\n",
    "y_val = torch.tensor(YValS1[:, roi1_cols])\n",
    "\n",
    "# Pre-computed feature matrices for the training and validation sets, as float32.\n",
    "X_trn = torch.tensor(np.load(path+'complex_cell_train.npy')).to(torch.float32)\n",
    "X_val = torch.tensor(np.load(path+'complex_cell_valid.npy')).to(torch.float32)\n",
    "\n",
    "# Training responses of the first two selected columns.\n",
    "y_first_voxel = y_trn[:, 0].to(torch.float32)\n",
    "y_second_voxel = y_trn[:, 1].to(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d65d147b-1261-4256-afb4-aad5b6569b91",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c62a2d8f-d3d4-4542-9044-c3ebf84bd7c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cpu:0\")\n",
    "\n",
    "# NOTE: sub_sample, sub_feature and v1_idx are not defined in this cell; they\n",
    "# must already exist in the kernel namespace before it runs.\n",
    "\n",
    "# Training subset: the first sub_sample rows and first sub_feature columns,\n",
    "# with the v1_idx response column reshaped to a column vector.\n",
    "X_sub = X_trn[:sub_sample, :sub_feature].to(device)\n",
    "y_sub = y_trn[:sub_sample, v1_idx].to(torch.float32).view(-1, 1).to(device)\n",
    "\n",
    "# Validation subset: the next 100 rows of the training data.\n",
    "X_val_sub = X_trn[sub_sample:sub_sample+100, :sub_feature].to(device)\n",
    "y_val_sub = y_trn[sub_sample:sub_sample+100, v1_idx].to(torch.float32).view(-1, 1).to(device)\n",
    "\n",
    "# Test subset: features and targets both come from the held-out X_val / y_val\n",
    "# pair loaded alongside X_trn / y_trn, so rows line up. (X_test belongs to the\n",
    "# earlier, unrelated experiments in this notebook and must not be used here.)\n",
    "X_test_sub = X_val[:sub_sample, :sub_feature].to(device)\n",
    "y_test_sub = y_val[:sub_sample, v1_idx].to(torch.float32).view(-1, 1).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16891c01-8acc-45c4-9ae7-166ed5042ce1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_29842/1708979526.py:2: DeprecationWarning: __array__ implementation doesn't accept a copy keyword, so passing copy=False failed. __array__ must implement 'dtype' and 'copy' keyword arguments.\n",
      "  model.fit(pd.DataFrame(X_sub), np.array(y_sub))\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Convert the tensors to NumPy explicitly. This avoids the __array__ copy\n",
    "# DeprecationWarning raised when tensors are handed to np.array/pd.DataFrame,\n",
    "# and flattens the (n, 1) target column vectors to 1-D: subtracting an (n, 1)\n",
    "# array from 1-D predictions would broadcast to (n, n) and give a wrong MSE.\n",
    "X_sub_np = X_sub.cpu().numpy()\n",
    "y_sub_np = y_sub.cpu().numpy().ravel()\n",
    "X_test_sub_np = X_test_sub.cpu().numpy()\n",
    "y_test_sub_np = y_test_sub.cpu().numpy().ravel()\n",
    "\n",
    "model = DNAMiteRegressor()\n",
    "model.fit(pd.DataFrame(X_sub_np), y_sub_np)\n",
    "# Flatten predictions as well so the residual is always element-wise.\n",
    "pred_test = np.asarray(model.predict(pd.DataFrame(X_test_sub_np))).ravel()\n",
    "print(np.mean(np.square(pred_test - y_test_sub_np)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f07ecc1b-7262-45ee-995e-54218bfa5293",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
