rm(list=ls())

library(glmnet)
library(readr)
library(dplyr)
library(stringr)
library(data.table)
library(minfi)
library(dplyr)
library(quantmod)
library(zoo)
library(glmnet)
library(ipflasso)
library(ebmr.alpha)
library(mr.ash.alpha)
library(ashr)
library(IlluminaHumanMethylation450kanno.ilmn12.hg19)

# --- Sample metadata ---
# Read the sample sheet and derive each sample's age as the first run of
# digits found in its `Title` field.
# NOTE: the data frame is bound to the name `sample`, which shadows
# base::sample for value lookups. Calls of the form `sample(...)` still
# resolve to the base function, because R skips non-function bindings when
# looking up a name in call position.
sample <- read_csv("C:/Document/Serieux/Travail/Data_analysis_and_papers/data_for_nash_experiment/age_clock/sample.csv")
sample <- sample %>%
  mutate(age = str_extract(Title, "\\d+") %>% as.numeric())

age <- sample$age

# --- Methylation beta values ---
# The file is read with CpGs in rows: the first column (`ID_REF`) holds the
# CpG identifiers and every remaining column is one sample.
DNAm <- fread("C:/Document/Serieux/Travail/Data_analysis_and_papers/data_for_nash_experiment/age_clock/GSE40279_average_beta.txt")
DNAm[1:10, 1:10]
cpg <- DNAm[, 1]
# Transpose so that rows are samples and columns are CpGs.
DNAm <- as.matrix(t(DNAm[, -1]))
cpg_ids <- cpg$ID_REF

# Fail fast if the response cannot be paired with the design matrix; without
# this check the mismatch only surfaces later inside the model-fitting calls.
# NOTE(review): this only checks counts. It assumes the rows of sample.csv are
# in the same order as the sample columns of the beta file -- TODO confirm.
if (length(age) != nrow(DNAm)) {
  stop("Number of ages (", length(age), ") does not match number of samples ",
       "in DNAm (", nrow(DNAm), ").", call. = FALSE)
}
if (anyNA(age)) {
  stop("Could not extract a numeric age from ", sum(is.na(age)),
       " sample title(s).", call. = FALSE)
}

# --- CpG annotation ---
# Keep only the CpG name and its relation-to-island label from the 450k
# annotation.
annotationdf <- getAnnotation(IlluminaHumanMethylation450kanno.ilmn12.hg19)
annotationdf <- data.frame(relation = annotationdf$Relation_to_Island,
                           cpg = annotationdf$Name)

# Filter annotationdf to keep only CpGs present in the methylation data
annotationdf_filtered <- annotationdf[annotationdf$cpg %in% cpg_ids, ]

# Also filter the methylation matrix and the CpG table to the annotated CpGs.
# `drop = FALSE` keeps DNAm a matrix even if a single CpG survives the filter.
keep_cpg <- cpg$ID_REF %in% annotationdf_filtered$cpg
DNAm <- DNAm[, keep_cpg, drop = FALSE]
cpg_filtered <- cpg[keep_cpg, ]

# Reorder the annotation rows to follow the CpG order of cpg_filtered, which
# is also the column order of DNAm after the filter above.
annotationdf_ordered <- annotationdf_filtered[match(cpg_filtered$ID_REF, annotationdf_filtered$cpg), ]

# Sanity check: annotation rows and methylation columns are aligned.
stopifnot(all(cpg_filtered$ID_REF == annotationdf_ordered$cpg))

annotationdf_ordered[1:10, ]

# Persist the per-CpG covariate table alongside the data splits.
write.csv(annotationdf_ordered, "C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split/GSE40279/infocov.csv", row.names = FALSE)

relation_labels <- annotationdf_ordered$relation

# Create blocks_list: a named list where each element is the vector of column
# indices of DNAm belonging to one relation-to-island group. Groups appear in
# order of first occurrence.
unique_groups <- unique(relation_labels)
blocks_list <- lapply(unique_groups, function(grp) which(relation_labels == grp))
names(blocks_list) <- unique_groups
# --- EXPERIMENT LOOP ---
# For ten seeds, draw a random 80/20 train/test split of the samples, fit
# several penalised / empirical-Bayes regressions of age on methylation, and
# record the test-set RMSE of each. The split and the cumulative results are
# written to disk on every iteration.

# Root-mean-squared error between observed values `y` and predictions `yt`.
rmse <- function(y, yt) sqrt(mean((y - yt)^2))

# Alternate between an ebmr fit and an ash shrinkage step until the tracked
# objective stops improving. Currently unused: its call in the loop below is
# commented out.
#
# X, y   design matrix and response, passed straight to ebmr().
# maxit  maximum number of ash / ebmr.update rounds (must be >= 1).
# tol    stop once the objective increases by less than this between rounds.
#        Note that the default 10e-3 equals 0.01.
#
# Returns the last ebmr fit object, with `$elbo` replaced by the objective
# trace (first element is -Inf) and `$b` set to the ash posterior means from
# the final round.
nash_noinfo_dynamic_td <- function(X, y, maxit = 100, tol = 10e-3) {
  stopifnot(maxit >= 1)

  fit_nash <- ebmr(X, y, maxiter = 3, ebnv_fn = ebnv.pm)
  elbo <- c(-Inf)

  for (iter in seq_len(maxit)) {
    # NOTE(review): ash() takes standard errors as its second argument;
    # `Sigma_diag` looks like a vector of variances -- confirm whether a
    # sqrt() is needed here.
    tt <- ash(fit_nash$mu, fit_nash$Sigma_diag)

    elbo <- c(elbo, fit_nash$elbo[length(fit_nash$elbo)] - tt$loglik)

    # elbo[1] is the -Inf placeholder, so the first real comparison is only
    # possible from the second round onwards.
    if (iter > 1 && (elbo[iter + 1] - elbo[iter] < tol)) {
      break
    }

    fit_nash <- ebmr.update(fit_nash,
                            mu0 = tt$result$PosteriorMean,
                            maxiter = 20)
  }

  fit_nash$elbo <- elbo
  fit_nash$b <- tt$result$PosteriorMean
  fit_nash
}

# Output locations (same paths as before, prefix factored out).
split_dir <- "C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/data_split/GSE40279"
results_file <- "C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement/results_realdata/GSE40279.RData"

y <- age
X <- DNAm
n_obs <- nrow(X)

lt <- list()
for (k in 1:10) {
  set.seed(k)
  idx_test <- sample(seq_len(n_obs), floor(0.2 * n_obs))

  # Slice once per split; drop = FALSE keeps both pieces as matrices.
  X_train <- X[-idx_test, , drop = FALSE]
  y_train <- y[-idx_test]
  X_test <- X[idx_test, , drop = FALSE]
  y_test <- y[idx_test]

  # --- IPF-LASSO ---
  fit_ipf <- cvr.ipflasso(
    X = X_train,
    Y = y_train,
    family = "gaussian",
    type.measure = "mse",
    alpha = 1,
    blocks = blocks_list,
    pf = rep(1, length(blocks_list)),  # Use uniform penalty for now
    nfolds = 5,
    ncv = 5
  )
  # First coefficient is the intercept, the rest are per-CpG slopes.
  coef_ipf <- fit_ipf$coeff[, fit_ipf$ind.bestlambda]
  pred_ipf <- coef_ipf[1] + X_test %*% coef_ipf[-1]

  # fitmnash <- nash_noinfo_dynamic_td(X = X_train, y = y_train)

  # --- Regular GLMNET models ---
  fit_lasso <- cv.glmnet(X_train, y_train, alpha = 1)
  fit_enet  <- cv.glmnet(X_train, y_train, alpha = 0.5)
  fit_ridge <- cv.glmnet(X_train, y_train, alpha = 0)

  # No `s` is supplied, so predict() uses its default choice of lambda.
  pred_lasso <- predict(fit_lasso, newx = X_test)
  pred_enet  <- predict(fit_enet,  newx = X_test)
  pred_ridge <- predict(fit_ridge, newx = X_test)

  # --- mr.ash ---
  # mr.ash() fits an intercept by default and stores it separately from
  # `beta`; it has to be added back, otherwise every prediction is shifted
  # and the reported RMSE is not comparable with the other methods.
  fit_mrash <- mr.ash(X_train, y_train)
  pred_mrash <- fit_mrash$intercept + X_test %*% fit_mrash$beta

  # --- RMSE ---
  res <- c(
    rmse(y_test, pred_lasso),
    rmse(y_test, pred_enet),
    rmse(y_test, pred_ridge),
    rmse(y_test, pred_mrash),
    rmse(y_test, pred_ipf)#,
    # rmse(y_train, X_train %*% fitmnash$mu) # Placeholder for Nash if implemented later
  )

  name <- c("Lasso", "Enet", "Ridge", "MRash", "IPF")#, "Nash")

  lt[[k]] <- list(rmse = res, name = name)

  # --- Save data split ---
  write.csv(y_train, paste0(split_dir, "/y_train", k, ".csv"), row.names = FALSE)
  write.csv(X_train, paste0(split_dir, "/X_train", k, ".csv"), row.names = FALSE)
  write.csv(y_test,  paste0(split_dir, "/y_test", k, ".csv"), row.names = FALSE)
  write.csv(X_test,  paste0(split_dir, "/X_test", k, ".csv"), row.names = FALSE)

  cat(sprintf("Fold %d RMSEs:\n", k))
  for (i in seq_along(name)) {
    cat(sprintf("%s: %.4f\n", name[i], res[i]))
  }

  # Checkpoint after every split so partial results survive an interruption.
  save(lt, file = results_file)
}
