{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9a318bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# benchmark_nash_mdn.py\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
    "from sklearn.linear_model import LassoCV\n",
    "\n",
    "# ============================\n",
    "# Configuration\n",
    "# ============================\n",
    "base_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split\")\n",
    "result_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/results_realdata\")\n",
    "datasets = [\"GSE40279\"]  # Add more if needed\n",
    "\n",
    "# ============================\n",
    "# Utility Functions\n",
    "# ============================\n",
    "def col_scale(X, with_mean=True, with_std=True):\n",
    "    mean = X.mean(axis=0) if with_mean else 0\n",
    "    std = X.std(axis=0) if with_std else 1\n",
    "    std[std == 0] = 1  # Avoid division by zero\n",
    "    return (X - mean) / std, std\n",
    "\n",
    "def emdn_posterior_means(X, betahat, sebetahat, **kwargs):\n",
    "    class Result:\n",
    "        post_mean = betahat\n",
    "        pi_np = np.ones((len(betahat), 3)) / 3\n",
    "        location = np.zeros((len(betahat), 3))\n",
    "        scale = np.ones((len(betahat), 3))\n",
    "    return Result()\n",
    "\n",
    "def posterior_mean_norm(betahat, sebetahat, log_pi, location, scale):\n",
    "    class PosteriorResult:\n",
    "        post_mean = betahat\n",
    "    return PosteriorResult()\n",
    "\n",
    "# ============================\n",
    "# NASH-MDN Core\n",
    "# ============================\n",
    "def nash_mdn(X, y, sideinfo=None, maxit=4, damping=0.99, eb_kwargs=None):\n",
    "    X = X.astype(np.float32)\n",
    "    y = y.astype(np.float32).flatten()\n",
    "    X, csd = col_scale(X)\n",
    "    y, ysd = col_scale(y.reshape(-1, 1))\n",
    "    y = y.flatten()\n",
    "\n",
    "    n, p = X.shape\n",
    "    sqn = np.sqrt(n).astype(np.float32)\n",
    "    beta = LassoCV(cv=5, max_iter=5000).fit(X, y).coef_\n",
    "\n",
    "    log_pi, scale, location = None, None, None\n",
    "    pi0 = 0\n",
    "\n",
    "    for o in range(maxit):\n",
    "        betahat_list, sebetahat_list = [], []\n",
    "        for k in range(p):\n",
    "            r = y - X @ beta\n",
    "            beta_new = beta.copy()\n",
    "            xk = X[:, k]\n",
    "            r_k = r + xk * beta_new[k]\n",
    "            betahat_k = np.dot(xk, r_k) / n\n",
    "            sebetahat_k = max(np.std(r_k) / sqn, 1e-4)\n",
    "\n",
    "            beta_k_new = betahat_k\n",
    "            beta_k_new = damping * beta_k_new + (1 - damping) * beta_new[k]\n",
    "\n",
    "            r += xk * (beta_new[k] - beta_k_new)\n",
    "            beta_k_new = damping * beta_k_new + (1 - damping) * beta_new[k]\n",
    "            beta_new[k] = beta_k_new\n",
    "\n",
    "            betahat_list.append(betahat_k)\n",
    "            sebetahat_list.append(sebetahat_k)\n",
    "\n",
    "        betahat_arr = np.array(betahat_list)\n",
    "        avg_se = np.mean(sebetahat_list)\n",
    "        beta = beta_new.copy()\n",
    "\n",
    "        res_sq_final = np.sum((y - X @ beta) ** 2)\n",
    "        sigma_0_term = np.dot(beta, beta - betahat_arr)\n",
    "        denom = n + p * (1 - pi0)\n",
    "        sigma_0 = np.sqrt(max((res_sq_final + sigma_0_term) / denom, 1e-8)) / sqn\n",
    "\n",
    "        res_sq = np.sum((y - X @ betahat_arr) ** 2)\n",
    "        drift_term = np.dot(betahat_arr, betahat_arr) / np.sqrt(n + p)\n",
    "        s = np.sqrt((res_sq + drift_term)) / sqn\n",
    "\n",
    "        drift_comp = (1 / (n / s**2 + 1 / sigma_0**2)) * (n / s**2)\n",
    "\n",
    "        if o < 1:\n",
    "            ash_input = betahat_arr\n",
    "            result = emdn_posterior_means(X=sideinfo, betahat=ash_input, sebetahat=np.array(sebetahat_list))\n",
    "        else:\n",
    "            ash_input = drift_comp * betahat_arr + (1 - drift_comp) * result.post_mean\n",
    "            result = emdn_posterior_means(X=sideinfo, betahat=ash_input, sebetahat=np.full_like(ash_input, sigma_0))\n",
    "            beta = result.post_mean\n",
    "\n",
    "    return ysd * result.post_mean / csd\n",
    "\n",
    "# ============================\n",
    "# Benchmark Script\n",
    "# ============================\n",
    "def run_nash_mdn(dataset_name):\n",
    "    dataset_path = base_path / dataset_name\n",
    "    if not dataset_path.exists():\n",
    "        print(f\"Dataset folder '{dataset_name}' not found, skipping.\")\n",
    "        return\n",
    "\n",
    "    rmses, mads = [], []\n",
    "    for k in range(1, 11):\n",
    "        try:\n",
    "            X_train = pd.read_csv(dataset_path / f\"X_train{k}.csv\").values.astype(np.float32)\n",
    "            y_train = pd.read_csv(dataset_path / f\"y_train{k}.csv\").values.astype(np.float32).flatten()\n",
    "            X_test = pd.read_csv(dataset_path / f\"X_test{k}.csv\").values.astype(np.float32)\n",
    "            y_test = pd.read_csv(dataset_path / f\"y_test{k}.csv\").values.astype(np.float32).flatten()\n",
    "            side_info = pd.read_csv(dataset_path / f\"infocov{k}.csv\").values.astype(np.float32)\n",
    "        except FileNotFoundError:\n",
    "            print(f\"  Fold {k}: missing file — skipping.\")\n",
    "            continue\n",
    "\n",
    "        # Run NASH-MDN\n",
    "        beta_est = nash_mdn(X_train, y_train, sideinfo=side_info)\n",
    "        y_pred_test = X_test @ beta_est\n",
    "\n",
    "        # Metrics\n",
    "        rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))\n",
    "        mad = mean_absolute_error(y_test, y_pred_test)\n",
    "        rmses.append(rmse)\n",
    "        mads.append(mad)\n",
    "        print(f\"  Fold {k}: RMSE = {rmse:.4f}, MAD = {mad:.4f}\")\n",
    "\n",
    "    # Save results\n",
    "    if rmses:\n",
    "        df = pd.DataFrame({'RMSE': rmses, 'MAD': mads})\n",
    "        df.to_csv(result_path / f\"{dataset_name}_nash_mdn.csv\", index=False)\n",
    "        print(f\"{dataset_name}: mean RMSE = {np.mean(rmses):.4f}, mean MAD = {np.mean(mads):.4f}\\n\")\n",
    "    else:\n",
    "        print(f\"{dataset_name}: no valid folds to evaluate.\\n\")\n",
    "\n",
    "# ============================\n",
    "# Run Benchmark\n",
    "# ============================\n",
    "if __name__ == \"__main__\":\n",
    "    for dname in datasets:\n",
    "        run_nash_mdn(dname)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
