{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9ca6ab81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "os.environ[\"R_HOME\"] = r\"C:\\Program Files\\R\\R-4.4.2\"  # Use raw string for Windows paths\n",
    "sys.path.append(r\"c:\\Document\\Serieux\\Travail\\python_work\\cEBNM_torch\\py\")\n",
    "\n",
    "\n",
    "from utils import * \n",
    "# Core libraries\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import warnings\n",
    "from scipy.special import logsumexp\n",
    "import multiprocessing\n",
    "\n",
    "# scikit-learn\n",
    "from sklearn.linear_model import LassoCV, RidgeCV\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# rpy2 for R integration\n",
    "import rpy2.robjects as ro\n",
    "from rpy2.robjects import numpy2ri, r\n",
    "from rpy2.robjects.packages import importr\n",
    "\n",
    "# Activate NumPy-to-R conversion\n",
    "numpy2ri.activate()\n",
    "ashr = importr(\"ashr\")\n",
    "\n",
    "# torch (if used later in the project)\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "# Optional: joblib for parallelism (only if needed)\n",
    "from joblib import Parallel, delayed\n",
    "\n",
    "# Local utilities (if you have a utils.py file)\n",
    "from utils import *\n",
    "\n",
    "# Define dataset class that includes observation noise\n",
    "class DensityRegressionDataset(Dataset):\n",
    "    def __init__(self, X, betahat, sebetahat):\n",
    "        self.X = torch.tensor(X, dtype=torch.float32)\n",
    "        self.betahat = torch.tensor(betahat, dtype=torch.float32)\n",
    "        self.sebetahat = torch.tensor(sebetahat, dtype=torch.float32)  # Noise level for each observation\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.X)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.betahat[idx], self.sebetahat[idx]  # Return the noise_std (sebetahat) as well\n",
    "\n",
    "# Mixture Density Network\n",
    "class MDN(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, n_gaussians, n_layers=4):\n",
    "        super(MDN, self).__init__()\n",
    "        \n",
    "        # Input layer\n",
    "        self.fc_in = nn.Linear(input_dim, hidden_dim)\n",
    "        \n",
    "        # Hidden layers\n",
    "        self.hidden_layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(n_layers)])\n",
    "        \n",
    "        # Output layers for the Gaussian parameters\n",
    "        self.pi = nn.Linear(hidden_dim, n_gaussians)  # Mixing coefficients (weights)\n",
    "        self.mu = nn.Linear(hidden_dim, n_gaussians)  # Means of Gaussians\n",
    "        self.log_sigma = nn.Linear(hidden_dim, n_gaussians)  # Log of standard deviations\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.fc_in(x))\n",
    "        \n",
    "        # Passing through each hidden layer\n",
    "        for layer in self.hidden_layers:\n",
    "            x = torch.relu(layer(x))\n",
    "        \n",
    "        # Outputs\n",
    "        pi = torch.softmax(self.pi(x), dim=1)  # Softmax for mixture weights\n",
    "        mu = self.mu(x)  # Mean of each Gaussian\n",
    "        log_sigma = self.log_sigma(x)  # Log standard deviation for stability\n",
    "        \n",
    "        return pi, mu, log_sigma\n",
    "\n",
    "# Negative log-likelihood loss for MDN with varying observation noise\n",
    "def mdn_loss_with_varying_noise(pi, mu, log_sigma, betahat, sebetahat):\n",
    "    sigma =torch.exp(log_sigma)  # Model predicted std (Gaussian std)\n",
    "    sebetahat = sebetahat.unsqueeze(1)  # Match the dimensions for broadcasting\n",
    "    total_sigma = torch.sqrt(sigma**2 + sebetahat**2)  # Combine with varying observation noise\n",
    "    m = torch.distributions.Normal(mu, total_sigma)\n",
    "    probs = m.log_prob(betahat.unsqueeze(1))  # Log probability of betahat under each Gaussian\n",
    "    log_probs = probs + torch.log(pi )  # Log-prob weighted by pi\n",
    "    nll = -torch.logsumexp(log_probs, dim=1)  # Logsumexp for numerical stability\n",
    "    return nll.mean()\n",
    "\n",
    "\n",
    "# Class to store the results\n",
    "class EmdnPosteriorMeanNorm:\n",
    "    def __init__(self, post_mean, post_mean2, post_sd, loss=0, model_param=None,\n",
    "                 location=None, scale=None, pi_np=None):\n",
    "        self.post_mean = post_mean\n",
    "        self.post_mean2 = post_mean2\n",
    "        self.post_sd = post_sd\n",
    "        self.loss = loss\n",
    "        self.model_param = model_param\n",
    "        self.location = location\n",
    "        self.scale = scale\n",
    "        self.pi_np = pi_np\n",
    "\n",
    "\n",
    "# Main function to train the model and compute posterior means, mean^2, and standard deviations\n",
    "def emdn_posterior_means(X, betahat, sebetahat, n_epochs=50 ,n_layers=4, n_gaussians=5, hidden_dim=64, batch_size=1024, lr=0.001, model_param=None):\n",
    "    # Standardize X\n",
    "    if X.ndim == 1:\n",
    "        X = X.reshape(-1, 1)\n",
    "    scaler = StandardScaler()\n",
    "    X_scaled = scaler.fit_transform(X)\n",
    "\n",
    "    # Create dataset and dataloader\n",
    "    dataset = DensityRegressionDataset(X_scaled, betahat, sebetahat)\n",
    "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "    # Initialize model and optimizer\n",
    "    input_dim = X_scaled.shape[1]  # Number of input features\n",
    "    model = MDN(input_dim=input_dim, hidden_dim=hidden_dim, n_gaussians=n_gaussians,n_layers=n_layers)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=lr)\n",
    "    if model_param is not None:\n",
    "        model.load_state_dict(model_param)\n",
    "\n",
    "    # Training loop\n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        running_loss = 0.0\n",
    "        for inputs, targets, noise_std in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            pi, mu, log_sigma = model(inputs)\n",
    "            loss = mdn_loss_with_varying_noise(pi, mu, log_sigma, targets, noise_std)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            running_loss += loss.item()\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            print(f'Epoch {epoch+1}/{n_epochs}, Loss: {running_loss/len(dataloader):.4f}')\n",
    "\n",
    "    # Once trained, generate posterior means for the entire dataset\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        # Convert the entire dataset into a batch for prediction\n",
    "        train_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)\n",
    "        for X_batch, _, _ in train_loader:\n",
    "            pi, mu, log_sigma = model(X_batch)  # Predict pi, mu, and log_sigma\n",
    "            #final_loss = mdn_loss_with_varying_noise(pi, mu, log_sigma, betahat_batch, sebetahat_batch).item()\n",
    "\n",
    "\n",
    "    # Convert predictions to numpy arrays\n",
    "    pi_np = pi.numpy()\n",
    "    mu_np = mu.numpy()\n",
    "    log_sigma_np = log_sigma.numpy()\n",
    "\n",
    "    # Initialize arrays to store the results\n",
    "    post_mean = np.zeros(len(betahat))\n",
    "    post_mean2 = np.zeros(len(betahat))\n",
    "    post_sd = np.zeros(len(betahat))\n",
    "\n",
    "    # Estimate posterior means for each observation\n",
    "    for i in range(len(betahat)):\n",
    "        result = posterior_mean_norm(\n",
    "            betahat=np.array([betahat[i]]),\n",
    "            sebetahat=np.array([sebetahat[i]]),\n",
    "            log_pi=np.log(pi_np[i, :]),\n",
    "            location=mu_np[i, :],\n",
    "            scale=np.sqrt(np.exp(log_sigma_np[i, :]) ** 2 )\n",
    "        )\n",
    "        post_mean[i] = result.post_mean\n",
    "        post_mean2[i] = result.post_mean2\n",
    "        post_sd[i] = result.post_sd\n",
    "\n",
    "    model_param= model.state_dict()\n",
    "    # Return all three arrays: posterior mean, mean^2, and standard deviation\n",
    "    return EmdnPosteriorMeanNorm(post_mean,\n",
    "                                 post_mean2, \n",
    "                                 post_sd,\n",
    "                                 loss= running_loss,\n",
    "                                 model_param=model_param,\n",
    "                                 location = mu_np,\n",
    "                                 scale= np.sqrt(np.exp(log_sigma_np ) ** 2 ),\n",
    "                                 pi_np=pi_np)\n",
    "class PosteriorMeanNormResult:\n",
    "    def __init__(self, post_mean, post_mean2, post_sd):\n",
    "        self.post_mean = post_mean\n",
    "        self.post_mean2 = post_mean2\n",
    "        self.post_sd = post_sd\n",
    "\n",
    "def posterior_mean_norm(betahat, sebetahat, log_pi, location, scale):\n",
    "    # Convert to numpy arrays and ensure shapes\n",
    "    betahat = np.asarray(betahat).reshape(-1, 1)\n",
    "    sebetahat = np.asarray(sebetahat).reshape(-1, 1)\n",
    "    location = np.asarray(location).reshape(1, -1)\n",
    "    scale = np.asarray(scale).reshape(1, -1)\n",
    "    pi = np.exp(log_pi)\n",
    "    pi = pi / np.sum(pi)  # normalize just in case\n",
    "\n",
    "    # Total variance\n",
    "    total_var = scale**2 + sebetahat**2\n",
    "    post_var = (scale**2 * sebetahat**2) / total_var\n",
    "    post_mean = (scale**2 * betahat + sebetahat**2 * location) / total_var\n",
    "\n",
    "    w = pi * np.exp(-0.5 * ((betahat - location)**2 / total_var)) / np.sqrt(total_var)\n",
    "    w = w / np.sum(w)\n",
    "\n",
    "    post_mean_mix = np.sum(w * post_mean)\n",
    "    post_mean2_mix = np.sum(w * (post_mean**2 + post_var))\n",
    "    post_sd_mix = np.sqrt(post_mean2_mix - post_mean_mix**2)\n",
    "\n",
    "    return PosteriorMeanNormResult(post_mean_mix, post_mean2_mix, post_sd_mix)\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LassoCV\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from scipy.special import logsumexp\n",
    "\n",
    "# Optional: for R-based ASH\n",
    "from rpy2.robjects import numpy2ri\n",
    "from rpy2.robjects.packages import importr\n",
    "numpy2ri.activate()\n",
    "ashr = importr(\"ashr\")\n",
    "\n",
    "# For MDN-based shrinkage\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "# --- Core Utility Functions ---\n",
    "def col_scale(X, with_mean=True, with_std=True):\n",
    "    scaler = StandardScaler(with_mean=with_mean, with_std=with_std)\n",
    "    X_scaled = scaler.fit_transform(X)\n",
    "    return X_scaled, scaler.scale_\n",
    "\n",
    "\n",
    "def get_data_loglik_normal(betahat, sebetahat, location, scale):\n",
    "    var = sebetahat[:, None] ** 2 + scale[None, :] ** 2\n",
    "    return -0.5 * (np.log(2 * np.pi * var) + (betahat[:, None] - location) ** 2 / var)\n",
    "\n",
    "\n",
    "def apply_log_sum_exp(data_loglik, log_pi):\n",
    "    combined_loglik = data_loglik + log_pi\n",
    "    return combined_loglik - logsumexp(combined_loglik, axis=1)[:, None]\n",
    "\n",
    "\n",
    "class PosteriorMeanNorm:\n",
    "    def __init__(self, post_mean, post_mean2, post_sd):\n",
    "        self.post_mean = post_mean\n",
    "        self.post_mean2 = post_mean2\n",
    "        self.post_sd = post_sd\n",
    "\n",
    "\n",
    "def posterior_mean_norm(betahat, sebetahat, log_pi, scale, location=None):\n",
    "    if location is None:\n",
    "        location = np.zeros_like(scale)\n",
    "\n",
    "    data_loglik = get_data_loglik_normal(betahat, sebetahat, location, scale)\n",
    "    log_post_assignment = apply_log_sum_exp(data_loglik, log_pi)\n",
    "\n",
    "    n, K = betahat.shape[0], scale.shape[0]\n",
    "    var = np.zeros((n, K))\n",
    "    for i in range(n):\n",
    "        if scale[0] == 0:\n",
    "            var[i, :] = np.concatenate(([0], 1 / ((1 / sebetahat[i]**2) + (1 / scale[1:]**2))))\n",
    "        else:\n",
    "            var[i, :] = 1 / ((1 / sebetahat[i]**2) + (1 / scale**2))\n",
    "\n",
    "    temp = np.zeros((n, K))\n",
    "    for i in range(n):\n",
    "        temp[i, :] = (var[i, :] / sebetahat[i]**2) * betahat[i] + location * (1 - var[i, :] / sebetahat[i]**2)\n",
    "\n",
    "    post_mean = np.sum(np.exp(log_post_assignment) * temp, axis=1)\n",
    "    post_mean2 = np.sum(np.exp(log_post_assignment) * (var + temp**2), axis=1)\n",
    "    post_sd = np.sqrt(post_mean2 - post_mean**2)\n",
    "    return PosteriorMeanNorm(post_mean, post_mean2, post_sd)\n",
    "\n",
    "\n",
    "# --- ASH Wrapper ---\n",
    "def call_r_ash_fit_all_with_postmean(beta, sigma):\n",
    "    from rpy2.rinterface_lib.sexp import NULLType\n",
    "    sebetahat = np.full_like(beta, sigma)\n",
    "    ash_obj = ashr.ash(betahat=beta, sebetahat=sebetahat, mixcompdist=\"normal\")\n",
    "    fitted_g = ash_obj.rx2(\"fitted_g\")\n",
    "    pi_r = np.array(fitted_g.rx2(\"pi\"), dtype=np.float32)\n",
    "    scale_r = np.array(fitted_g.rx2(\"sd\"), dtype=np.float32)\n",
    "    posterior_mean_r = ash_obj.rx2(\"result\").rx2(\"PosteriorMean\")\n",
    "    if isinstance(posterior_mean_r, NULLType):\n",
    "        raise RuntimeError(\"R ash() returned NULL for result$PosteriorMean\")\n",
    "    posterior_mean = np.array(posterior_mean_r, dtype=np.float32)\n",
    "    log_pi = np.log(np.clip(pi_r, 1e-12, 1.0))\n",
    "    return log_pi, scale_r, posterior_mean\n",
    "\n",
    "\n",
    "# --- Unified NASH Framework ---\n",
    "def nash_with_sideinfo(X, y, sideinfo=None, method=\"ash\", maxit=20, damping=0.8, eb_kwargs=None):\n",
    "    X = X.astype(np.float32)\n",
    "    y = y.astype(np.float32).flatten()\n",
    "    X, csd = col_scale(X)\n",
    "    y, ysd = col_scale(y.reshape(-1, 1))\n",
    "    y = y.flatten()\n",
    "\n",
    "    n, p = X.shape\n",
    "    sqn = np.sqrt(n).astype(np.float32)\n",
    "    beta = LassoCV(cv=5, max_iter=5000).fit(X, y).coef_\n",
    "\n",
    "    log_pi, scale, location = None, None, None\n",
    "    td_beta = []\n",
    "    model_param = None\n",
    "    pi0=0\n",
    "    for o in range(maxit):\n",
    "        r = y - X @ beta\n",
    "        beta_new = beta.copy()\n",
    "        betahat_list, sebetahat_list = [], []\n",
    "\n",
    "        for k in range(p):\n",
    "            xk = X[:, k]\n",
    "            r_k = r + xk * beta_new[k]\n",
    "            betahat_k = np.dot(xk, r_k) / n\n",
    "            sebetahat_k = max(np.std(r_k) / sqn, 1e-4)\n",
    "\n",
    "            if log_pi is None:\n",
    "                beta_k_new = betahat_k\n",
    "            else:\n",
    "                if method == \"ash\":\n",
    "                    pm = posterior_mean_norm(\n",
    "                    betahat=np.array([betahat_k]),\n",
    "                    sebetahat=np.array([sigma_0]) ,  \n",
    "                    log_pi=log_pi,\n",
    "                    scale=scale,\n",
    "                    location=location\n",
    "                    )\n",
    "                    beta_k_new = pm.post_mean[0]\n",
    "                else: \n",
    "                    pm = posterior_mean_norm(\n",
    "                    betahat=np.array([betahat_k]),\n",
    "                    sebetahat=np.array([sebetahat_k]),\n",
    "                    log_pi=np.log( result.pi_np[k, :]),\n",
    "                    location=result.location[k, :],\n",
    "                    scale=result.scale[k, :]\n",
    "                    )\n",
    "                    \n",
    "                    beta_k_new = pm.post_mean[0]\n",
    "\n",
    "            r += xk * (beta_new[k]  - beta_k_new)\n",
    "            beta_k_new = damping * beta_k_new + (1 - damping) * beta_new[k]\n",
    "            beta_new[k]=beta_k_new\n",
    "            betahat_list.append(betahat_k)\n",
    "            sebetahat_list.append(sebetahat_k)\n",
    "\n",
    "        betahat_arr = np.array(betahat_list)\n",
    "        avg_se = np.mean(sebetahat_list)\n",
    "        beta = beta_new.copy()\n",
    "        res_sq_final = np.sum((y - X @ beta ) ** 2)\n",
    "        sigma_0_term = np.dot(beta, beta - betahat_arr )\n",
    "        #print(log_pi)\n",
    "        if log_pi is None:\n",
    "            denom = n + p * (1 - pi0)\n",
    "        else:\n",
    "            if method == \"ash\":\n",
    "                pi0 =np.exp(log_pi[0])\n",
    "            else:\n",
    "                pi0=np.mean(result.pi_np[ :, 0])\n",
    "        if denom <= 1e-8:\n",
    "            denom = n + p * 1e-3  # avoid divide by 0 or small number\n",
    "        sigma_0 = np.sqrt(max((res_sq_final + sigma_0_term) / denom, 1e-8)) / sqn\n",
    "\n",
    "        res_sq = np.sum((y - X @ betahat_arr) ** 2)\n",
    "        drift_term = np.dot(betahat_arr,  betahat_arr) / np.sqrt(n + p)\n",
    "        s = np.sqrt((res_sq + drift_term)) / sqn\n",
    "\n",
    "        # Update drift_comp\n",
    "        drift_comp = (1 / (n / s**2 + 1 / sigma_0**2)) * (n / s**2)\n",
    "\n",
    "        # Update beta with ash\n",
    "        ash_input = drift_comp * betahat_arr + (1 - drift_comp) * beta\n",
    "        if method == \"ash\":\n",
    "            log_pi, scale, postmean = call_r_ash_fit_all_with_postmean(ash_input, sigma_0)\n",
    "            location = np.zeros_like(scale)\n",
    "            #beta = postmean.copy()\n",
    "\n",
    "        elif method == \"nash-mdn\":\n",
    "            result = emdn_posterior_means(sideinfo, ash_input, np.full_like(ash_input,avg_se), model_param=model_param, **(eb_kwargs or {}))\n",
    "            log_pi= result.pi_np\n",
    "            beta = result.post_mean\n",
    "            model_param = result.model_param\n",
    "\n",
    "        elif method == \"nash-cash\": \n",
    "            result = cash(ash_input, np.full_like(ash_input, avg_se), sideinfo)\n",
    "            #beta = result.post_mean\n",
    "\n",
    "        else:\n",
    "            raise ValueError(\"Unknown method: choose from 'ash', 'nash-mdn', or 'nash-cash'\")\n",
    "\n",
    "        td_beta.append(ysd * beta / csd)\n",
    "        print(f\"Iteration {o+1} complete\")\n",
    "    if method==\"ash\":\n",
    "        beta = postmean.copy() \n",
    "    else:\n",
    "        beta = result.post_mean\n",
    "    final_beta = ysd * beta / csd \n",
    "    return final_beta\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "97786671",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
    "from pathlib import Path\n",
    "\n",
    "# Set path\n",
    "base_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split\")\n",
    "result_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/results_realdata\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c99f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_nn(dataset_name):\n",
    "    dataset_path = base_path / dataset_name\n",
    "    if not dataset_path.exists():\n",
    "        print(f\"Dataset folder '{dataset_name}' not found, skipping.\")\n",
    "        return\n",
    "\n",
    "    rmses, mads = [], []\n",
    "    for k in range(1, 11):\n",
    "        try:\n",
    "            X_train = pd.read_csv(dataset_path / f\"X_train{k}.csv\").values.astype(np.float32)\n",
    "            side_info = pd.read_csv(dataset_path / f\"X_train{k}.csv\").values.astype(np.float32)\n",
    "            y_train = pd.read_csv(dataset_path / f\"y_train{k}.csv\").values.astype(np.float32).flatten()\n",
    "            X_test = pd.read_csv(dataset_path / f\"X_test{k}.csv\").values.astype(np.float32)\n",
    "            y_test = pd.read_csv(dataset_path / f\"y_test{k}.csv\").values.astype(np.float32).flatten()\n",
    "        except FileNotFoundError as e:\n",
    "            print(f\"  Fold {k}: missing file — skipping this fold.\")\n",
    "            continue\n",
    "\n",
    "        # Move data to torch \n",
    "        # Define model\n",
    "        model = FeedforwardNN(X_train.shape[1]).to(device)\n",
    "        optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "        loss_fn = nn.MSELoss()\n",
    "\n",
    "        # Train model\n",
    "        fit_nash =  nash_with_sideinfo( X= X_train,\n",
    "                            y=y_train,\n",
    "                            sideinfo=side_info,\n",
    "                             maxit=10 )\n",
    "\n",
    "         \n",
    "        y_pred_test =  X_test @ fit_nash\n",
    "\n",
    "        # Compute metrics\n",
    "        rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))\n",
    "        mad = mean_absolute_error(y_test, y_pred_test)\n",
    "        rmses.append(rmse)\n",
    "        mads.append(mad)\n",
    "        print(f\"  Fold {k}: RMSE = {rmse:.4f}, MAD = {mad:.4f}\")\n",
    "\n",
    "    # Save results if any folds succeeded\n",
    "    if rmses:\n",
    "        df = pd.DataFrame({'RMSE': rmses, 'MAD': mads})\n",
    "        df.to_csv(result_path / f\"{dataset_name}_pytorch_nn.csv\", index=False)\n",
    "        print(f\"{dataset_name}: mean RMSE = {np.mean(rmses):.4f}, mean MAD = {np.mean(mads):.4f}\\n\")\n",
    "    else:\n",
    "        print(f\"{dataset_name}: no valid folds to evaluate.\\n\")\n",
    "\n",
    "# Run for all datasets\n",
    "datasets = [\"Airpassenger\", \"SNP500\", \"spaRNA_seq\", \"TCGA\", \"GSE40279\"]\n",
    "for dname in datasets:\n",
    "    run_nn(dname)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
