{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "24b8334c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "os.environ[\"R_HOME\"] = r\"C:\\Program Files\\R\\R-4.4.2\"  # Use raw string for Windows paths\n",
    "sys.path.append(r\"c:\\Document\\Serieux\\Travail\\python_work\\cEBNM_torch\\py\")\n",
    "\n",
    "\n",
    "from utils import * \n",
    "# Core libraries\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    " \n",
    "\n",
    "# scikit-learn\n",
    "from sklearn.linear_model import LassoCV, RidgeCV\n",
    "from sklearn.metrics import mean_squared_error\n",
    "\n",
    "# rpy2 for R integration\n",
    "import rpy2.robjects as ro\n",
    "from rpy2.robjects import numpy2ri, r\n",
    "from rpy2.robjects.packages import importr\n",
    "\n",
    "# Activate NumPy-to-R conversion\n",
    "numpy2ri.activate()\n",
    "ashr = importr(\"ashr\")\n",
    "\n",
    "# torch (if used later in the project)\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    " \n",
    "# Local utilities (if you have a utils.py file)\n",
    "from utils_mix import *\n",
    "from empirical_mdn import emdn_posterior_means\n",
    "from utils_mix import *\n",
    "from numerical_routine import *\n",
    "from distribution_operation import *\n",
    "from posterior_computation import posterior_mean_norm\n",
    "from cash_solver import Cash_posterior_means\n",
    "from ash import call_r_ash_fit_all_with_postmean\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "from sklearn.metrics import mean_squared_error, mean_absolute_error\n",
    "from sklearn.linear_model import LassoCV\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "eea6ecff",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def col_scale(X, with_mean=True, with_std=True):\n",
    "    \"\"\"Standardize the columns of ``X`` with a freshly fitted ``StandardScaler``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : array-like\n",
    "        Data handed to ``StandardScaler.fit_transform``.\n",
    "    with_mean, with_std : bool\n",
    "        Forwarded unchanged to the ``StandardScaler`` constructor.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(X_scaled, scale)`` where ``scale`` is the fitted scaler's ``scale_``\n",
    "        attribute.\n",
    "    \"\"\"\n",
    "    standardizer = StandardScaler(with_mean=with_mean, with_std=with_std)\n",
    "    transformed = standardizer.fit_transform(X)\n",
    "    column_scale = standardizer.scale_\n",
    "    return transformed, column_scale\n",
    "\n",
    "\n",
    "def call_r_ash_fit_all_with_postmean(beta, sigma):\n",
    "    \"\"\"Run R's ``ash`` with a normal mixture prior and unpack the fit.\n",
    "\n",
    "    ``sigma`` is filled into an array with the shape and dtype of ``beta`` and\n",
    "    used as the standard errors of the estimates.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ``(log_pi, scale, posterior_mean)``: the log of the fitted mixture\n",
    "        weights (weights are clipped to ``[1e-12, 1]`` before the log), the\n",
    "        fitted mixture standard deviations, and ``result$PosteriorMean``.\n",
    "        All three are float32 arrays.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    RuntimeError\n",
    "        If ``result$PosteriorMean`` comes back as R ``NULL``.\n",
    "    \"\"\"\n",
    "    from rpy2.rinterface_lib.sexp import NULLType\n",
    "\n",
    "    fit = ashr.ash(\n",
    "        betahat=beta,\n",
    "        sebetahat=np.full_like(beta, sigma),\n",
    "        mixcompdist=\"normal\",\n",
    "    )\n",
    "\n",
    "    prior = fit.rx2(\"fitted_g\")\n",
    "    weights = np.array(prior.rx2(\"pi\"), dtype=np.float32)\n",
    "    sds = np.array(prior.rx2(\"sd\"), dtype=np.float32)\n",
    "\n",
    "    r_post_mean = fit.rx2(\"result\").rx2(\"PosteriorMean\")\n",
    "    if isinstance(r_post_mean, NULLType):\n",
    "        raise RuntimeError(\"R ash() returned NULL for result$PosteriorMean\")\n",
    "\n",
    "    # Clip before taking the log so zero weights do not turn into -inf.\n",
    "    return (\n",
    "        np.log(np.clip(weights, 1e-12, 1.0)),\n",
    "        sds,\n",
    "        np.array(r_post_mean, dtype=np.float32),\n",
    "    )\n",
    "\n",
    "\n",
    "def get_data_loglik_normal(betahat, sebetahat, location, scale):\n",
    "    \"\"\"Gaussian log-density of every observation under every mixture component.\n",
    "\n",
    "    Entry ``[i, k]`` is the log-density of ``betahat[i]`` under a normal centred\n",
    "    at ``location[k]`` with variance ``sebetahat[i]**2 + scale[k]**2``.\n",
    "    \"\"\"\n",
    "    observed = betahat[:, None]\n",
    "    total_var = sebetahat[:, None] ** 2 + scale[None, :] ** 2\n",
    "    squared_resid = (observed - location) ** 2\n",
    "    log_norm_const = np.log(2 * np.pi * total_var)\n",
    "    return -0.5 * (log_norm_const + squared_resid / total_var)\n",
    "\n",
    "\n",
    "def apply_log_sum_exp(data_loglik, log_pi):\n",
    "    \"\"\"Turn component log-likelihoods into log posterior assignment weights.\n",
    "\n",
    "    The log mixture weights are added to ``data_loglik`` and the row-wise\n",
    "    ``logsumexp`` is subtracted, so each row exponentiates to weights that\n",
    "    sum to one.\n",
    "    \"\"\"\n",
    "    weighted = data_loglik + log_pi\n",
    "    row_normaliser = logsumexp(weighted, axis=1)\n",
    "    return weighted - row_normaliser[:, None]\n",
    "\n",
    "\n",
    "class PosteriorMeanNorm:\n",
    "    \"\"\"Plain container for posterior summaries.\n",
    "\n",
    "    The three constructor arguments are stored unchanged as the attributes\n",
    "    ``post_mean``, ``post_mean2`` and ``post_sd``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, post_mean, post_mean2, post_sd):\n",
    "        self.post_mean, self.post_mean2, self.post_sd = post_mean, post_mean2, post_sd\n",
    "\n",
    "\n",
    "def posterior_mean_norm(betahat, sebetahat, log_pi, scale, location=None):\n",
    "    \"\"\"Posterior moments under a mixture-of-normals prior with normal noise.\n",
    "\n",
    "    Observation ``betahat[i]`` has standard error ``sebetahat[i]``; the prior is\n",
    "    a mixture of normals with log-weights ``log_pi``, means ``location`` and\n",
    "    standard deviations ``scale``.  A component whose ``scale`` is zero acts as\n",
    "    a point mass at its location, wherever it sits in the mixture.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    betahat, sebetahat : 1-D ndarray of length n\n",
    "        Estimates and their standard errors.\n",
    "    log_pi : ndarray\n",
    "        Log mixture weights; added to the ``(n, K)`` log-likelihood matrix.\n",
    "    scale : 1-D ndarray of length K\n",
    "        Component standard deviations.\n",
    "    location : 1-D ndarray of length K, optional\n",
    "        Component means; defaults to zeros shaped like ``scale``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    PosteriorMeanNorm\n",
    "        Holds ``post_mean``, ``post_mean2`` (second moment) and ``post_sd``,\n",
    "        each with one entry per observation.\n",
    "    \"\"\"\n",
    "    if location is None:\n",
    "        location = np.zeros_like(scale)\n",
    "\n",
    "    data_loglik = get_data_loglik_normal(betahat, sebetahat, location, scale)\n",
    "    log_post_assignment = apply_log_sum_exp(data_loglik, log_pi)\n",
    "    post_assignment = np.exp(log_post_assignment)\n",
    "\n",
    "    noise_var = np.asarray(sebetahat, dtype=np.float64)[:, None] ** 2\n",
    "    prior_var = np.asarray(scale, dtype=np.float64)[None, :] ** 2\n",
    "\n",
    "    # Conditional posterior variance for every (observation, component) pair.\n",
    "    # A zero prior variance gives 1/0 = inf and therefore a conditional variance\n",
    "    # of exactly 0, so point-mass components need no special branch.\n",
    "    with np.errstate(divide=\"ignore\"):\n",
    "        var = 1.0 / (1.0 / noise_var + 1.0 / prior_var)\n",
    "\n",
    "    # Weight of the observation versus the component mean in each conditional mean.\n",
    "    shrink = var / noise_var\n",
    "    temp = shrink * np.asarray(betahat)[:, None] + location * (1.0 - shrink)\n",
    "\n",
    "    post_mean = np.sum(post_assignment * temp, axis=1)\n",
    "    post_mean2 = np.sum(post_assignment * (var + temp**2), axis=1)\n",
    "    # Rounding can push the variance marginally below zero; clamp to avoid NaN.\n",
    "    post_sd = np.sqrt(np.maximum(post_mean2 - post_mean**2, 0.0))\n",
    "    return PosteriorMeanNorm(post_mean, post_mean2, post_sd)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "70f142d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def nash_mdn(X, y, sideinfo=None, maxit=4, damping=0.99, eb_kwargs=None):\n",
    "    \"\"\"Estimate regression coefficients by alternating damped coordinate sweeps\n",
    "    with shrinkage from ``emdn_posterior_means``.\n",
    "\n",
    "    ``X`` and ``y`` are cast to float32 and standardized with ``col_scale``; the\n",
    "    coefficients are initialised with ``LassoCV``.  Every outer iteration runs\n",
    "    one sweep over the columns to collect a least-squares estimate and a\n",
    "    standard error per coordinate, then shrinks those estimates with\n",
    "    ``emdn_posterior_means`` using ``sideinfo`` as its ``X`` argument.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : ndarray of shape (n, p)\n",
    "        Design matrix.\n",
    "    y : ndarray with n entries\n",
    "        Response; flattened internally.\n",
    "    sideinfo : optional\n",
    "        Passed as ``X`` to ``emdn_posterior_means``.  Presumably one row per\n",
    "        column of the design matrix; confirm against ``emdn_posterior_means``.\n",
    "    maxit : int\n",
    "        Number of outer iterations; must be at least 1.\n",
    "    damping : float\n",
    "        Weight given to the fresh estimate in each coordinate update; the\n",
    "        previous value receives ``1 - damping``.\n",
    "    eb_kwargs : dict, optional\n",
    "        Extra keyword arguments forwarded to every ``emdn_posterior_means`` call.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ndarray\n",
    "        ``ysd * result.post_mean / csd``: the final posterior means rescaled by\n",
    "        the response and column scales returned by ``col_scale``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``maxit`` is smaller than 1, since no shrinkage fit would exist.\n",
    "    \"\"\"\n",
    "    if maxit < 1:\n",
    "        raise ValueError(\"maxit must be at least 1\")\n",
    "    eb_kwargs = dict(eb_kwargs or {})\n",
    "\n",
    "    X = X.astype(np.float32)\n",
    "    y = y.astype(np.float32).flatten()\n",
    "    X, csd = col_scale(X)\n",
    "    y, ysd = col_scale(y.reshape(-1, 1))\n",
    "    y = y.flatten()\n",
    "\n",
    "    n, p = X.shape\n",
    "    sqn = np.sqrt(n).astype(np.float32)\n",
    "    beta = LassoCV(cv=5, max_iter=5000).fit(X, y).coef_\n",
    "    pi0 = 0  # fixed at zero here, so the denominator below equals n + p\n",
    "    result = None\n",
    "\n",
    "    for o in range(maxit):\n",
    "        # Start the sweep from the current coefficients and keep one running\n",
    "        # residual, so each coordinate sees the updates made before it.\n",
    "        beta_new = beta.copy()\n",
    "        r = y - X @ beta_new\n",
    "        betahat_list, sebetahat_list = [], []\n",
    "        for k in range(p):\n",
    "            xk = X[:, k]\n",
    "            r_k = r + xk * beta_new[k]  # residual with coordinate k taken out\n",
    "            betahat_k = np.dot(xk, r_k) / n\n",
    "            sebetahat_k = max(np.std(r_k) / sqn, 1e-4)  # floor keeps the SE positive\n",
    "\n",
    "            # Damp once, then keep the residual and the coefficient in sync.\n",
    "            beta_k_new = damping * betahat_k + (1 - damping) * beta_new[k]\n",
    "            r = r_k - xk * beta_k_new\n",
    "            beta_new[k] = beta_k_new\n",
    "\n",
    "            betahat_list.append(betahat_k)\n",
    "            sebetahat_list.append(sebetahat_k)\n",
    "\n",
    "        betahat_arr = np.array(betahat_list)\n",
    "        beta = beta_new\n",
    "\n",
    "        if o < 1:\n",
    "            # First pass: shrink the raw estimates with their own standard errors.\n",
    "            result = emdn_posterior_means(\n",
    "                X=sideinfo,\n",
    "                betahat=betahat_arr,\n",
    "                sebetahat=np.array(sebetahat_list),\n",
    "                **eb_kwargs\n",
    "            )\n",
    "        else:\n",
    "            res_sq_final = np.sum((y - X @ beta) ** 2)\n",
    "            sigma_0_term = np.dot(beta, beta - betahat_arr)\n",
    "            denom = n + p * (1 - pi0)\n",
    "            sigma_0 = np.sqrt(max((res_sq_final + sigma_0_term) / denom, 1e-8)) / sqn\n",
    "\n",
    "            res_sq = np.sum((y - X @ betahat_arr) ** 2)\n",
    "            drift_term = np.dot(betahat_arr, betahat_arr) / np.sqrt(n + p)\n",
    "            s = np.sqrt(res_sq + drift_term) / sqn\n",
    "\n",
    "            # Weight between the fresh estimates and the previous posterior means.\n",
    "            drift_comp = (1 / (n / s**2 + 1 / sigma_0**2)) * (n / s**2)\n",
    "            ash_input = drift_comp * betahat_arr + (1 - drift_comp) * result.post_mean\n",
    "            result = emdn_posterior_means(\n",
    "                X=sideinfo,\n",
    "                betahat=ash_input,\n",
    "                sebetahat=np.full_like(ash_input, sigma_0),\n",
    "                **eb_kwargs\n",
    "            )\n",
    "            beta = result.post_mean\n",
    "\n",
    "    # Map the coefficients back from the standardized scale.\n",
    "    return ysd * result.post_mean / csd\n",
    " \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d17063a7",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "could not convert string to float: 'OpenSea'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[4], line 13\u001b[0m\n\u001b[0;32m     11\u001b[0m X_test \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(dataset_path \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_test\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mvalues\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[0;32m     12\u001b[0m y_test \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(dataset_path \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_test\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mvalues\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\u001b[38;5;241m.\u001b[39mflatten()\n\u001b[1;32m---> 13\u001b[0m df \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset_path\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minfocov\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mk\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m.csv\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat32\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     14\u001b[0m df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlocation\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m 
df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlocation\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mstr\u001b[38;5;241m.\u001b[39mreplace(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m     15\u001b[0m df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpg_id\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m df[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpg_id\u001b[39m\u001b[38;5;124m\"\u001b[39m]\u001b[38;5;241m.\u001b[39mstr\u001b[38;5;241m.\u001b[39mreplace(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "\u001b[1;31mValueError\u001b[0m: could not convert string to float: 'OpenSea'"
     ]
    }
   ],
   "source": [
    "# Evaluate nash_mdn on the ten pre-split folds of one dataset.\n",
    "dataset_name = \"GSE40279\"\n",
    "base_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split\")\n",
    "result_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/results_realdata\")\n",
    "datasets = [dataset_name]\n",
    "dataset_path = base_path / dataset_name\n",
    "\n",
    "rmses, mads = [], []\n",
    "for k in range(1, 11):\n",
    "    # Only the reads are guarded, so a missing results directory or a failure\n",
    "    # inside the model is not reported as a missing fold.\n",
    "    try:\n",
    "        X_train = pd.read_csv(dataset_path / f\"X_train{k}.csv\").values.astype(np.float32)\n",
    "        y_train = pd.read_csv(dataset_path / f\"y_train{k}.csv\").values.astype(np.float32).flatten()\n",
    "        X_test = pd.read_csv(dataset_path / f\"X_test{k}.csv\").values.astype(np.float32)\n",
    "        y_test = pd.read_csv(dataset_path / f\"y_test{k}.csv\").values.astype(np.float32).flatten()\n",
    "        # Keep the annotation table as a DataFrame: it holds string columns\n",
    "        # (e.g. location 'OpenSea') and cannot be cast to float32.\n",
    "        df = pd.read_csv(dataset_path / f\"infocov{k}.csv\")\n",
    "    except FileNotFoundError:\n",
    "        print(f\"  Fold {k}: missing file — skipping.\")\n",
    "        continue\n",
    "\n",
    "    df[\"location\"] = df[\"location\"].str.replace('\"', '')\n",
    "    df[\"cpg_id\"] = df[\"cpg_id\"].str.replace('\"', '')\n",
    "\n",
    "    # Get all unique location types\n",
    "    location_types = sorted(df[\"location\"].unique())\n",
    "\n",
    "    # One-hot encode location\n",
    "    one_hot = pd.get_dummies(df[\"location\"])\n",
    "\n",
    "    # Concatenate CpG IDs with the one-hot encoded data\n",
    "    result_df = pd.concat([df[\"cpg_id\"], one_hot], axis=1)\n",
    "\n",
    "    # Only the numeric one-hot columns are usable as side information.\n",
    "    sideinfo_array = one_hot.to_numpy(dtype=np.float32)\n",
    "\n",
    "    # Centre with the training mean only, so no test labels reach the fit.\n",
    "    mu = y_train.mean()\n",
    "    y_test = y_test - mu\n",
    "    y_train = y_train - mu\n",
    "\n",
    "    # Fit on the training fold; the test fold is used for scoring only.\n",
    "    beta_est = nash_mdn(X=X_train, y=y_train, sideinfo=sideinfo_array, maxit=2)\n",
    "\n",
    "    # NOTE(review): nash_mdn standardizes X internally, so predicting from the raw\n",
    "    # X_test assumes the feature columns are already centred - confirm.\n",
    "    y_pred_test = X_test @ beta_est\n",
    "\n",
    "    # Metrics\n",
    "    rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))\n",
    "    mad = mean_absolute_error(y_test, y_pred_test)\n",
    "    rmses.append(rmse)\n",
    "    mads.append(mad)\n",
    "    print(f\"  Fold {k}: RMSE = {rmse:.4f}, MAD = {mad:.4f}\")\n",
    "\n",
    "    # Checkpoint the per-fold metrics after every fold so a later failure does\n",
    "    # not lose finished work.\n",
    "    results_df = pd.DataFrame({'RMSE': rmses, 'MAD': mads})\n",
    "    results_df.to_csv(result_path / f\"{dataset_name}_nash_mdn.csv\", index=False)\n",
    "\n",
    "# Summary over all folds that could be evaluated\n",
    "if rmses:\n",
    "    print(f\"{dataset_name}: mean RMSE = {np.mean(rmses):.4f}, mean MAD = {np.mean(mads):.4f}\\n\")\n",
    "    # View the side-information table of the last evaluated fold\n",
    "    print(result_df.head())\n",
    "else:\n",
    "    print(f\"{dataset_name}: no valid folds to evaluate.\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07332c88",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "could not convert string to float: 'OpenSea'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[8], line 9\u001b[0m\n\u001b[0;32m      7\u001b[0m X_test \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(dataset_path \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mX_test\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mvalues\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[0;32m      8\u001b[0m y_test \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mread_csv(dataset_path \u001b[38;5;241m/\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my_test\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mk\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.csv\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39mvalues\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\u001b[38;5;241m.\u001b[39mflatten()\n\u001b[1;32m----> 9\u001b[0m side_info \u001b[38;5;241m=\u001b[39m \u001b[43mpd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_csv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset_path\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43minfocov\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mk\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m.csv\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat32\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mValueError\u001b[0m: could not convert string to float: 'OpenSea'"
     ]
    }
   ],
   "source": [
    "# Load a single fold for interactive inspection.\n",
    "# NOTE(review): `k` is not set in this cell; it is whatever value the fold loop\n",
    "# in an earlier cell left behind.\n",
    "base_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split\")\n",
    "result_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/results_realdata\")\n",
    "datasets = [\"GSE40279\"] \n",
    "dataset_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split/GSE40279\")\n",
    "X_train = pd.read_csv(dataset_path / f\"X_train{k}.csv\").values.astype(np.float32)\n",
    "y_train = pd.read_csv(dataset_path / f\"y_train{k}.csv\").values.astype(np.float32).flatten()\n",
    "X_test = pd.read_csv(dataset_path / f\"X_test{k}.csv\").values.astype(np.float32)\n",
    "y_test = pd.read_csv(dataset_path / f\"y_test{k}.csv\").values.astype(np.float32).flatten()\n",
    "# Keep the annotation table as a DataFrame: it holds string columns\n",
    "# (e.g. location 'OpenSea'), so casting it to float32 raises ValueError, and the\n",
    "# following cells index it by column name.\n",
    "df = pd.read_csv(dataset_path / f\"infocov{k}.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "458b0726",
   "metadata": {},
   "outputs": [],
   "source": [
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a72d2b10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strip the literal quote characters from the text columns of the annotation table.\n",
    "df[\"location\"] = df[\"location\"].str.replace('\"', '')\n",
    "df[\"cpg_id\"] = df[\"cpg_id\"].str.replace('\"', '')\n",
    "\n",
    "# Get all unique location types\n",
    "location_types = sorted(df[\"location\"].unique())\n",
    "\n",
    "# One-hot encode location\n",
    "one_hot = pd.get_dummies(df[\"location\"])\n",
    "\n",
    "# Concatenate CpG IDs with the one-hot encoded data\n",
    "result_df = pd.concat([df[\"cpg_id\"], one_hot], axis=1)\n",
    "\n",
    "# View result\n",
    "print(result_df.head())\n",
    "\n",
    "# Only the numeric one-hot columns can be cast to float32; `df` itself still\n",
    "# contains the string columns.\n",
    "sideinfo_array = one_hot.to_numpy(dtype=np.float32)\n",
    "\n",
    "# Centre with the training mean only (no test labels reach the fit) and leave\n",
    "# the loaded arrays untouched.\n",
    "mu = y_train.mean()\n",
    "y_test_centered = y_test - mu\n",
    "y_train_centered = y_train - mu\n",
    "\n",
    "# Fit on the training fold; the test fold is used for scoring only.\n",
    "beta_est = nash_mdn(X=X_train, y=y_train_centered, sideinfo=sideinfo_array, maxit=2)\n",
    "\n",
    "# NOTE(review): nash_mdn standardizes X internally, so predicting from the raw\n",
    "# X_test assumes the feature columns are already centred - confirm.\n",
    "y_pred_test = X_test @ beta_est\n",
    "\n",
    "# Metrics\n",
    "rmse = np.sqrt(mean_squared_error(y_test_centered, y_pred_test))\n",
    "mad = mean_absolute_error(y_test_centered, y_pred_test)\n",
    "rmses.append(rmse)\n",
    "mads.append(mad)\n",
    "print(f\"  Fold {k}: RMSE = {rmse:.4f}, MAD = {mad:.4f}\")\n",
    "\n",
    "# Save results\n",
    "# `dataset_name` may be a one-element list; use the bare name in file names.\n",
    "dataset_label = dataset_name[0] if isinstance(dataset_name, (list, tuple)) else dataset_name\n",
    "if rmses:\n",
    "    # Separate name so the annotation table in `df` is not overwritten.\n",
    "    results_df = pd.DataFrame({'RMSE': rmses, 'MAD': mads})\n",
    "    results_df.to_csv(result_path / f\"{dataset_label}_nash_mdn.csv\", index=False)\n",
    "    print(f\"{dataset_label}: mean RMSE = {np.mean(rmses):.4f}, mean MAD = {np.mean(mads):.4f}\\n\")\n",
    "else:\n",
    "    print(f\"{dataset_label}: no valid folds to evaluate.\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4560ebf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-validated evaluation of nash_mdn over the pre-split folds.\n",
    "dataset_name = \"GSE40279\"\n",
    "base_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split\")\n",
    "result_path = Path(\"C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/results_realdata\")\n",
    "datasets = [dataset_name]\n",
    "dataset_path = base_path / dataset_name\n",
    "\n",
    "rmses, mads = [], []\n",
    "for k in range(1, 11):\n",
    "    # Guard the reads only: other failures should surface instead of being\n",
    "    # reported as a missing fold.\n",
    "    try:\n",
    "        X_train = pd.read_csv(dataset_path / f\"X_train{k}.csv\").values.astype(np.float32)\n",
    "        y_train = pd.read_csv(dataset_path / f\"y_train{k}.csv\").values.astype(np.float32).flatten()\n",
    "        X_test = pd.read_csv(dataset_path / f\"X_test{k}.csv\").values.astype(np.float32)\n",
    "        y_test = pd.read_csv(dataset_path / f\"y_test{k}.csv\").values.astype(np.float32).flatten()\n",
    "        # The annotation table has string columns, so it stays a DataFrame\n",
    "        # instead of being cast to float32.\n",
    "        df = pd.read_csv(dataset_path / f\"infocov{k}.csv\")\n",
    "    except FileNotFoundError:\n",
    "        print(f\"  Fold {k}: missing file — skipping.\")\n",
    "        continue\n",
    "\n",
    "    df[\"location\"] = df[\"location\"].str.replace('\"', '')\n",
    "    df[\"cpg_id\"] = df[\"cpg_id\"].str.replace('\"', '')\n",
    "\n",
    "    # Get all unique location types\n",
    "    location_types = sorted(df[\"location\"].unique())\n",
    "\n",
    "    # One-hot encode location\n",
    "    one_hot = pd.get_dummies(df[\"location\"])\n",
    "\n",
    "    # Concatenate CpG IDs with the one-hot encoded data\n",
    "    result_df = pd.concat([df[\"cpg_id\"], one_hot], axis=1)\n",
    "\n",
    "    # Side information is the numeric one-hot block; `df` still has text columns.\n",
    "    sideinfo_array = one_hot.to_numpy(dtype=np.float32)\n",
    "\n",
    "    # Centre both folds with the training mean so test labels stay unseen.\n",
    "    mu = y_train.mean()\n",
    "    y_test = y_test - mu\n",
    "    y_train = y_train - mu\n",
    "\n",
    "    # Now pass sideinfo to your model:\n",
    "    beta_est = nash_mdn(X=X_train, y=y_train, sideinfo=sideinfo_array, maxit=2)\n",
    "\n",
    "    # NOTE(review): nash_mdn standardizes X internally, so predicting from the raw\n",
    "    # X_test assumes the feature columns are already centred - confirm.\n",
    "    y_pred_test = X_test @ beta_est\n",
    "\n",
    "    # Metrics\n",
    "    rmse = np.sqrt(mean_squared_error(y_test, y_pred_test))\n",
    "    mad = mean_absolute_error(y_test, y_pred_test)\n",
    "    rmses.append(rmse)\n",
    "    mads.append(mad)\n",
    "    print(f\"  Fold {k}: RMSE = {rmse:.4f}, MAD = {mad:.4f}\")\n",
    "\n",
    "    # Checkpoint the metrics gathered so far; a separate name keeps `df` intact.\n",
    "    results_df = pd.DataFrame({'RMSE': rmses, 'MAD': mads})\n",
    "    results_df.to_csv(result_path / f\"{dataset_name}_nash_mdn.csv\", index=False)\n",
    "\n",
    "# Summary over all folds that could be evaluated\n",
    "if rmses:\n",
    "    print(f\"{dataset_name}: mean RMSE = {np.mean(rmses):.4f}, mean MAD = {np.mean(mads):.4f}\\n\")\n",
    "else:\n",
    "    print(f\"{dataset_name}: no valid folds to evaluate.\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
