library(genlasso)
library(ggplot2)
library(igraph)
library(Matrix)
library(gridExtra)
library(glmnet)
library(genlasso)
library(Matrix)
library(FNN)
library(ebmr.alpha)
library(mr.ash.alpha)
library(ashr)
# Root-mean-squared error between observed values `y` and predictions `yt`
# (vectors or one-column matrices of equal length).
rmse <- function(y, yt) {
  residual <- y - yt
  sqrt(mean(residual^2))
}

# Center and/or scale the columns of a numeric matrix.
#
# x        : numeric matrix.
# center   : subtract column means (computed with na.rm = TRUE).
# scale    : divide by the column standard deviation about the centering
#            values, sqrt(sum((x - cm)^2) / (n - 1)); with center = FALSE this
#            is the root-mean-square, as in base::scale(). Zero SDs are
#            replaced by 1 so constant columns are left at 0 instead of NaN.
# add_attr : attach "scaled:center" / "scaled:scale" (when applied) and "d",
#            the column sums of squares of the returned matrix.
# rows,cols: optional row/column subsets taken before anything else.
# Returns the transformed matrix.
colScale <- function(x, center = TRUE, scale = TRUE, add_attr = TRUE, rows = NULL, cols = NULL) {
  # Explicit defaults instead of `%||%`, which is only in base R >= 4.4.
  if (!is.null(rows) || !is.null(cols)) {
    if (is.null(rows)) rows <- seq_len(nrow(x))
    if (is.null(cols)) cols <- seq_len(ncol(x))
    x <- x[rows, cols, drop = FALSE]
  }

  n <- nrow(x)
  cm <- if (center) colMeans(x, na.rm = TRUE) else rep(0, ncol(x))
  x <- sweep(x, 2, cm, "-")

  # SD about `cm`, computed on the already-centered matrix.
  csd <- if (scale) sqrt(colSums(x^2) / (n - 1)) else rep(1, length(cm))
  csd[csd == 0] <- 1  # Prevent division by zero
  x <- sweep(x, 2, csd, "/")

  if (add_attr) {
    if (center) attr(x, "scaled:center") <- cm
    if (scale) attr(x, "scaled:scale") <- csd
    # Column sums of squares of the result: n - 1 for a standardized
    # non-constant column, 0 for a constant one. (The previous
    # (n - 1) * csd^2 / csd^2 was n - 1 for every column.)
    attr(x, "d") <- colSums(x^2)
  }
  x
}

# Alternate ebmr regression fits with an ash fit on the current coefficients.
#
# X, y  : design matrix and response passed to ebmr().
# maxit : maximum number of ash / ebmr.update rounds (must be >= 1).
# tol   : stop once the tracked objective increases by less than `tol`
#         between consecutive rounds. NOTE(review): 10e-3 == 0.01; confirm
#         1e-3 was not intended.
# Returns the last ebmr fit with `$elbo` replaced by the tracked objective
# (-Inf followed by one value per round) and `$b` set to the last ash
# posterior mean.
nash_noinfo_dynamic_td <- function(X, y, maxit = 100, tol = 10e-3) {
  # `tt` is read after the loop, so at least one round is required.
  stopifnot(maxit >= 1)

  y.fit.ebr <- ebmr(X, y, maxiter = 20, ebnv_fn = ebnv.pm)
  y.fit.nash <- y.fit.ebr
  elbo <- -Inf

  for (k in seq_len(maxit)) {
    # NOTE(review): ash() expects standard errors as its second argument;
    # confirm Sigma_diag is on that scale (a variance would need sqrt()).
    tt <- ash(y.fit.nash$mu, y.fit.nash$Sigma_diag)

    # Tracked objective: last ebmr ELBO minus the ash log-likelihood.
    latest_elbo <- y.fit.nash$elbo[length(y.fit.nash$elbo)]
    elbo <- c(elbo, latest_elbo - tt$loglik)

    # elbo[1] is the -Inf sentinel, so compare only from round 2 on.
    if (k > 1 && elbo[k + 1] - elbo[k] < tol) {
      break
    }
    y.fit.nash <- ebmr.update(y.fit.nash,
                              mu0 = tt$result$PosteriorMean,
                              maxiter = 20)
  }

  y.fit.nash$elbo <- elbo
  y.fit.nash$b <- tt$result$PosteriorMean
  y.fit.nash
}

# Benchmark: in each replicate, hold out one spot (a column of `tt_sub`),
# regress its expression on all other spots over a random 80% of the rows,
# and record the test RMSE on the remaining 20% for ridge / elastic net /
# lasso (glmnet), mr.ash, and a graph fused lasso whose penalty follows a
# kNN graph over the predictor spots' coordinates. Each split is also
# exported as CSV files.
set.seed(123)

# All input and output paths are resolved from this experiment directory.
base_dir <- "C:/Document/Serieux/Travail/Data_analysis_and_papers/nash_experiement"
split_dir <- file.path(base_dir, "data_split", "spaRNA_seq")
results_file <- file.path(base_dir, "results_realdata", "spaRNA_seq.RData")

# Must provide `count_sub` and `xy_coords`, both used below.
load(file.path(base_dir, "script", "real_data_analysis", "LIBD_sample12.RData"))
tt_sub <- as.matrix(count_sub)

# Undirected kNN graph with one vertex per row of `coords`; each unordered
# neighbour pair appears once.
build_knn_graph <- function(coords, k = 6) {
  nn <- get.knn(coords, k = k)
  from <- rep(seq_len(nrow(coords)), each = k)
  to <- as.vector(t(nn$nn.index))
  edges <- unique(cbind(pmin(from, to), pmax(from, to)))
  graph_from_edgelist(edges, directed = FALSE)
}

n_rep <- 10
lt <- list()

for (o in seq_len(n_rep)) {
  set.seed(o)

  # Held-out spot is the response; every other spot is a predictor.
  idx <- sample(seq_len(ncol(tt_sub)), size = 1)
  y <- tt_sub[, idx]
  X <- tt_sub[, -idx]

  # Random 20% of the rows form the test set. An empty index would make
  # y[-idx_test] empty as well, so refuse that case.
  idx_test <- sample(seq_along(y), size = floor(0.2 * length(y)))
  stopifnot(length(idx_test) > 0)
  y_train <- y[-idx_test]
  y_test <- y[idx_test]
  X_train <- X[-idx_test, ]
  X_test <- X[idx_test, ]

  # glmnet: alpha = 0 is ridge, alpha = 1 is lasso. Call order is kept
  # (0, 0.5, 1) so the random CV folds match across runs.
  fit_ridge <- cv.glmnet(X_train, y_train, alpha = 0)
  predictions_ridge <- predict(fit_ridge, newx = X_test)

  fit_enet <- cv.glmnet(X_train, y_train, alpha = 0.5)
  predictions_enet <- predict(fit_enet, newx = X_test)

  fit_lasso <- cv.glmnet(X_train, y_train, alpha = 1)
  predictions_lasso <- predict(fit_lasso, newx = X_test)

  # Graph fused lasso: predictors are spots, so the graph is built over the
  # coordinates of the predictor spots.
  coords <- xy_coords[-idx, ]
  g <- build_knn_graph(coords, k = 6)
  fit_fl <- fusedlasso(y = y_train, X = X_train, graph = g)

  # No tuning: use the solution halfway along the lambda path (index >= 1).
  lambda_index <- max(1, floor(length(fit_fl$lambda) / 2))
  predictions_glasso <- X_test %*% coef(fit_fl)$beta[, lambda_index]

  # mr.ash fits an intercept alongside `beta`; it must be part of the
  # prediction. NOTE(review): assumes the fit stores it as `$intercept`.
  fit_mr <- mr.ash(X_train, y_train)
  mr_intercept <- if (is.null(fit_mr$intercept)) 0 else fit_mr$intercept
  predictions_mr <- mr_intercept + X_test %*% fit_mr$beta

  res <- c(
    rmse(y_test, predictions_lasso),
    rmse(y_test, predictions_enet),
    rmse(y_test, predictions_ridge),
    rmse(y_test, predictions_mr),
    rmse(y_test, predictions_glasso)
  )
  name <- c("Lasso", "Enet", "Ridge", "MRash", "Glasso")
  lt[[o]] <- list(rmse = res, name = name)

  # Export this split. All predictor columns are written so the columns of
  # X_train / X_test line up with the rows of infocov_.
  write.csv(coords,
            file.path(split_dir, paste0("infocov_", o, ".csv")),
            row.names = FALSE)
  write.csv(y_train, file.path(split_dir, paste0("y_train", o, ".csv")), row.names = FALSE)
  write.csv(X_train, file.path(split_dir, paste0("X_train", o, ".csv")), row.names = FALSE)
  write.csv(y_test, file.path(split_dir, paste0("y_test", o, ".csv")), row.names = FALSE)
  write.csv(X_test, file.path(split_dir, paste0("X_test", o, ".csv")), row.names = FALSE)

  # Checkpoint the accumulated results after every replicate.
  save(lt, file = results_file)
}
