# --- SETUP ---
rm(list = ls())
library(dplyr)
library(quantmod)
library(zoo)
library(glmnet)
library(ipflasso)
library(ebmr.alpha)
library(mr.ash.alpha)
library(ashr)
library(nnet)
library(TCGAbiolinks)
library(clusterProfiler)
library(org.Hs.eg.db)

# --- INPUT ---
# Load the prepared data set. The file is expected to create an object named
# `data` whose `X` element holds the response as one of its columns
# (presumably days to death, judging by the path -- confirm).
load("C:/Document/Serieux/Travail/Data_analysis_and_papers/data_for_nash_experiment/days_to_death/deathpred.RData")
target_index <- 1056
X_full <- data$X
y_full <- X_full[, target_index]   # response = column `target_index` of data$X
X_full <- X_full[, -target_index]  # every other column is a predictor

# --- KEGG Annotation ---
# Map the predictor column names (gene symbols) to Entrez IDs, run a KEGG
# enrichment on them, and keep, per returned pathway, its "/"-separated list
# of Entrez IDs. These pathway memberships are later used as IPF-LASSO blocks.
gene_symbols <- colnames(X_full)
if (is.null(gene_symbols)) {
  stop("X_full has no column names; gene symbols are required for the ",
       "KEGG annotation.", call. = FALSE)
}
symbol_to_entrez <- bitr(gene_symbols, fromType = "SYMBOL",
                         toType = "ENTREZID", OrgDb = "org.Hs.eg.db")
ekegg <- enrichKEGG(gene = symbol_to_entrez$ENTREZID, organism = "hsa",
                    pvalueCutoff = 0.05)
# enrichKEGG() returns NULL when none of the genes can be mapped to KEGG;
# fail here with a clear message instead of an obscure `@` slot error below.
if (is.null(ekegg)) {
  stop("enrichKEGG() returned NULL: no mapped gene could be matched to a ",
       "KEGG pathway, so no pathway blocks can be built.", call. = FALSE)
}
# NOTE(review): the `result` slot is read directly. Depending on the
# clusterProfiler version it may contain every tested pathway rather than only
# those passing pvalueCutoff (as.data.frame(ekegg) applies the cutoff) --
# confirm which set is intended.
kegg_pathways <- setNames(as.list(ekegg@result$geneID), ekegg@result$ID)
# Reverse lookup Entrez ID -> symbol, used to translate pathway members back
# to the column names of X_full.
entrez_to_symbol <- setNames(symbol_to_entrez$SYMBOL, symbol_to_entrez$ENTREZID)

# Build group list: walk the pathways in order, give each one the genes it
# contains that no earlier pathway has already claimed, and number the
# non-empty pathways 1, 2, ... in that order. `blocks_list` holds the column
# indices per block; `group_vector` holds the block number per column.
used <- logical(length(gene_symbols))
group_vector <- rep(NA_integer_, length(gene_symbols))
blocks_list <- list()
group_counter <- 1
for (path in names(kegg_pathways)) {
  symbols <- unique(
    entrez_to_symbol[unlist(strsplit(kegg_pathways[[path]], "/"))]
  )
  idx <- which(!used & gene_symbols %in% symbols)
  if (length(idx) == 0) {
    next  # every gene of this pathway is absent or already assigned
  }
  blocks_list[[path]] <- idx
  group_vector[idx] <- group_counter
  used[idx] <- TRUE
  group_counter <- group_counter + 1
}

# Genes not claimed by any pathway form one final block named "Other".
leftover <- which(is.na(group_vector))
if (length(leftover) != 0) {
  group_vector[leftover] <- group_counter
  blocks_list[["Other"]] <- leftover
}

# Sanity checks: every column has a group, the vector lines up with the
# columns of X_full, and no column index sits in more than one block.
stopifnot(
  !any(is.na(group_vector)),
  length(group_vector) == ncol(X_full),
  all(table(unlist(blocks_list)) == 1)  # unique assignment
)

# Save group vector (one group id per predictor column, in column order).
write.csv(
  group_vector,
  "C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split/TCGA/infocov.csv",
  row.names = FALSE
)

# --- EXPERIMENT LOOP ---
# Repeated random 80/20 train/test splits. Replicate k draws its split with
# set.seed(k), fits six models on the training part, stores their test RMSEs
# in lt[[k]], writes the split to CSV, and saves `lt` so that an interrupted
# run resumes at the first missing replicate.
# NOTE: mr.ash and Nash test predictions include an intercept. A results file
# written before that fix holds intercept-free RMSEs for those two models;
# delete it before resuming so replicates stay comparable.

n_rep <- 10
out_dir <- "C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement"
split_dir <- file.path(out_dir, "data_split", "TCGA")
results_file <- file.path(out_dir, "results_realdata", "TCGA.RData")

# Root-mean-square error between observed `y` and predictions `yt`.
rmse <- function(y, yt) sqrt(mean((y - yt)^2))

# Nash fit without external information: start from an ebmr() fit, then
# alternate between (a) fitting an ash prior to the current posterior means
# `mu` and (b) re-running ebmr.update() with that ash posterior mean as `mu0`.
# It stops once the tracked objective increases by less than `tol` (or
# decreases).
#   X, y  : training design matrix and response, passed straight to ebmr().
#   maxit : maximum number of ash/ebmr alternations (must be >= 1).
#   tol   : stopping threshold on the objective change (10e-3 == 0.01).
# Returns the last ebmr fit. Its `elbo` is replaced by the tracked objective
# trace (starting at -Inf), and `b` is set to the last ash posterior mean.
nash_noinfo_dynamic_td <- function(X, y, maxit = 100, tol = 10e-3) {
  stopifnot(maxit >= 1)
  y.fit.ebr <- ebmr(X, y, maxiter = 20, ebnv_fn = ebnv.pm)
  y.fit.nash <- y.fit.ebr
  elbo <- c(-Inf)
  for (iter in seq_len(maxit)) {
    # NOTE(review): ash()'s second argument is the standard error of each
    # estimate; Sigma_diag looks like a posterior *variance* -- confirm
    # whether sqrt() is intended here.
    tt <- ash(y.fit.nash$mu, y.fit.nash$Sigma_diag)

    elbo <- c(elbo, y.fit.nash$elbo[length(y.fit.nash$elbo)] - tt$loglik)
    if (iter > 1 && elbo[iter + 1] - elbo[iter] < tol) {
      break
    }
    y.fit.nash <- ebmr.update(y.fit.nash,
                              mu0 = tt$result$PosteriorMean,
                              maxiter = 20)
  }

  y.fit.nash$elbo <- elbo
  y.fit.nash$b <- tt$result$PosteriorMean
  y.fit.nash
}

if (file.exists(results_file)) {
  load(results_file)  # restores `lt` from a previous (possibly partial) run
} else {
  lt <- list()
}

# `(length(lt) + 1):n_rep` would count *down* (11, 10) once every replicate
# is stored, refitting and overwriting them; this range is empty instead.
for (k in length(lt) + seq_len(max(0, n_rep - length(lt)))) {
  set.seed(k)
  idx_test <- sample(seq_len(nrow(X_full)), floor(0.2 * nrow(X_full)))
  X_train <- X_full[-idx_test, , drop = FALSE]
  X_test <- X_full[idx_test, , drop = FALSE]
  y_train <- y_full[-idx_test]
  y_test <- y_full[idx_test]

  # --- IPF-LASSO ---
  # Equal penalty factors for all KEGG blocks.
  fit_ipf <- cvr.ipflasso(
    X = X_train,
    Y = y_train,
    family = "gaussian",
    type.measure = "mse",
    alpha = 1,
    blocks = blocks_list,
    pf = rep(1, length(blocks_list)),
    nfolds = 5,
    ncv = 5
  )
  # First entry of the coefficient column is the intercept.
  coef_ipf <- fit_ipf$coeff[, fit_ipf$ind.bestlambda]
  pred_ipf <- coef_ipf[1] + X_test %*% coef_ipf[-1]

  # --- Nash ---
  # ebmr() is called with only X and y (no intercept argument), so centre the
  # training data and add the training mean of y back when predicting.
  # NOTE(review): assumes ebmr() does not fit an intercept internally.
  x_means <- colMeans(X_train)
  y_mean <- mean(y_train)
  fitmnash <- nash_noinfo_dynamic_td(X = sweep(X_train, 2, x_means),
                                     y = y_train - y_mean)
  pred_nash <- y_mean + sweep(X_test, 2, x_means) %*% fitmnash$mu

  # --- Regular GLMNET models ---
  fit_lasso <- cv.glmnet(X_train, y_train, alpha = 1)
  fit_enet <- cv.glmnet(X_train, y_train, alpha = 0.5)
  fit_ridge <- cv.glmnet(X_train, y_train, alpha = 0)

  pred_lasso <- predict(fit_lasso, newx = X_test)
  pred_enet <- predict(fit_enet, newx = X_test)
  pred_ridge <- predict(fit_ridge, newx = X_test)

  # --- mr.ash ---
  # mr.ash() fits an intercept by default; it must be part of the prediction.
  fit_mrash <- mr.ash(X_train, y_train)
  mrash_b0 <- if (is.null(fit_mrash$intercept)) 0 else fit_mrash$intercept
  pred_mrash <- mrash_b0 + X_test %*% fit_mrash$beta

  # --- RMSE ---
  name <- c("Lasso", "Enet", "Ridge", "MRash", "IPF", "Nash")
  res <- c(
    rmse(y_test, pred_lasso),
    rmse(y_test, pred_enet),
    rmse(y_test, pred_ridge),
    rmse(y_test, pred_mrash),
    rmse(y_test, pred_ipf),
    rmse(y_test, pred_nash)
  )

  lt[[k]] <- list(rmse = res, name = name)

  # --- Save data split ---
  write.csv(y_train, file.path(split_dir, paste0("y_train", k, ".csv")), row.names = FALSE)
  write.csv(X_train, file.path(split_dir, paste0("X_train", k, ".csv")), row.names = FALSE)
  write.csv(y_test, file.path(split_dir, paste0("y_test", k, ".csv")), row.names = FALSE)
  write.csv(X_test, file.path(split_dir, paste0("X_test", k, ".csv")), row.names = FALSE)

  cat(sprintf("Fold %d RMSEs:\n", k))
  for (i in seq_along(name)) {
    cat(sprintf("%s: %.4f\n", name[i], res[i]))
  }

  save(lt, file = results_file)
}

# --- Save results ---
# `lt` is already written to `results_file` after every replicate inside the
# loop above, so nothing further is saved here.
