library(MASS)
library(mvtnorm)

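# ------ INPUT ------
# Q: number of source tasks
# p: covariate dimension
# n: target sample size
# nq: source sample sizes (length-Q vector)
# s: signal strength, i.e. the squared l2-norm of the target coefficient beta
# SSR: shift-to-signal ratios (length-Q vector); source q uses the coefficient
#      beta + delta_q with ||delta_q||^2 = SSR[q] * s
# sig_err: noise standard deviation (square root of the common noise variance)
# K: number of cross-validation folds
# T: simulation-run index, used to offset the data-generation seed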

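# ------ OUTPUT ------
# A 4 x Q matrix whose column q is (risk, risk_tm[q], risk_wtm, risk_pm[q]):
# risk: excess risk of the target-only MNI (repeated in every column)
# risk_tm: excess risk of the TM estimate built from source q
# risk_wtm: excess risk of the WTM estimate (repeated in every column)
# risk_pm: excess risk of the pooled MNI fitted on target + source q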


# Minimum-norm interpolator (MNI) of Y on X.
# Underdetermined case (more columns than rows): the minimum l2-norm
# interpolating solution X' (X X')^{-1} Y. Otherwise: the least-squares
# solution (X' X)^{-1} X' Y. Linear systems are solved directly instead of
# forming explicit inverses, which is cheaper and numerically more stable.
.mni_fit <- function(X, Y) {
  if (ncol(X) > nrow(X)) {
    crossprod(X, solve(tcrossprod(X), Y))
  } else {
    solve(crossprod(X), crossprod(X, Y))
  }
}

# Transfer MNI (TM): the source estimate b_src corrected by the MNI fitted to
# the target residuals, i.e. b_src + X^+ (Y - X b_src).
.tm_fit <- function(X, Y, b_src) {
  b_src + .mni_fit(X, Y - X %*% b_src)
}

# Excess risk (b_hat - beta)' Sigma (b_hat - beta), returned as a scalar.
.excess_risk <- function(b_hat, beta, Sigma) {
  d <- b_hat - beta
  drop(crossprod(d, Sigma %*% d))
}

# One simulation replicate.
# Generates target and source data, then fits four estimators:
#   - the target-only MNI,
#   - the per-source TM estimates,
#   - the CV-weighted WTM estimate,
#   - the pooled MNI.
# Returns their excess risks as a 4 x Q matrix whose column q is
# c(target, TM_q, WTM, pooled_q); the arguments are described in the header
# comment above.
# Side effect: assigns the global `nrows` (= 4). The script's
# result-collection code indexes the stacked simulation output with it.
TRANSFER_MNI <- function(Q, p, n, nq, s, SSR, sig_err, K, T) {
  # K >= 2 is needed for the CV loss spread; K <= n keeps every fold non-empty.
  stopifnot(K >= 2, K <= n)

  # ------ SYNTHETIC DATA GENERATION ------
  # Model coefficients are drawn under a fixed seed, so they are identical
  # across simulation runs T (they still depend on p, Q, s and SSR).
  # ||beta||^2 = s and ||delta_q||^2 = SSR[q] * s by construction.
  set.seed(1)
  beta <- { u <- rnorm(p); sqrt(s) * u / sqrt(sum(u^2)) }
  deltas <- vector("list", Q)
  betas <- vector("list", Q)
  for (q in seq_len(Q)) {
    deltas[[q]] <- { u <- rnorm(p); sqrt(SSR[q] * s) * u / sqrt(sum(u^2)) }
    betas[[q]] <- beta + deltas[[q]]
  }

  # Homogeneous target and source covariances with the eigenvalue decay of
  # Bartlett et al. (2020) (exponent 3/2 on the log factor). `nrow = p`
  # keeps diag() from building an identity matrix when p == 1.
  idx <- seq_len(p)
  eigenval <- 15 * idx^(-1) * log((idx + 1) * exp(1) / 2)^(-1.5)
  Sigma <- diag(eigenval, nrow = p)
  Sigma_s <- diag(eigenval, nrow = p)

  # Design matrices and responses; the draw order (X, each Xs, Y noise,
  # each Ys noise) fixes the random stream for a given seed.
  set.seed(05232025 + T)
  X <- rmvnorm(n = n, mean = rep(0, p), sigma = Sigma)
  Xs <- vector("list", Q)
  for (q in seq_len(Q)) {
    Xs[[q]] <- rmvnorm(n = nq[q], mean = rep(0, p), sigma = Sigma_s)
  }
  Y <- X %*% beta + sig_err * matrix(rnorm(n), ncol = 1)
  Ys <- vector("list", Q)
  for (q in seq_len(Q)) {
    Ys[[q]] <- Xs[[q]] %*% betas[[q]] + sig_err * matrix(rnorm(nq[q]), ncol = 1)
  }

  # ------ TARGET TASK AND TM ESTIMATES ------
  hbeta_s <- lapply(seq_len(Q), function(q) .mni_fit(Xs[[q]], Ys[[q]]))  # source-only MNI
  hbeta <- .mni_fit(X, Y)                                                # target-only MNI
  hbeta_tm <- lapply(hbeta_s, function(b) .tm_fit(X, Y, b))              # TM estimates

  # ------ K-FOLD CROSS-VALIDATION FOR INFORMATIVE SOURCE DETECTION -------
  # Contiguous folds that put every target observation in exactly one test
  # fold; when K divides n these are the blocks of size n/K.
  fold_id <- ((seq_len(n) - 1L) * K) %/% n + 1L
  cv_losses <- numeric(K)
  cv_losses_tm <- matrix(0, nrow = Q, ncol = K)
  for (k in seq_len(K)) {
    test <- fold_id == k
    # drop = FALSE keeps single-row folds as matrices.
    X_train <- X[!test, , drop = FALSE]
    Y_train <- Y[!test, , drop = FALSE]
    X_test <- X[test, , drop = FALSE]
    Y_test <- Y[test, , drop = FALSE]
    n_test <- nrow(X_test)

    # Target-only fit does not depend on q, so it is computed once per fold.
    hcv_beta <- .mni_fit(X_train, Y_train)
    cv_losses[k] <- sum((Y_test - X_test %*% hcv_beta)^2) / n_test
    for (q in seq_len(Q)) {
      hcv_beta_tm <- .tm_fit(X_train, Y_train, hbeta_s[[q]])
      cv_losses_tm[q, k] <- sum((Y_test - X_test %*% hcv_beta_tm)^2) / n_test
    }
  }
  # final cv loss of the target and each transfer task
  loss <- mean(cv_losses)
  loss_tm <- rowMeans(cv_losses_tm)

  # A source is informative when its TM CV loss exceeds the target CV loss
  # by less than eps0 times the fold-to-fold SD of the target loss (floored).
  sig_loss <- sd(cv_losses)
  eps0 <- 1 / 2
  infosource_ind <- which(loss_tm - loss < eps0 * max(sig_loss, 0.01))

  # ------ WEIGHTED-INFORMATIVE TRANSFER MNI -------
  if (length(infosource_ind) == 0) {
    # no informative source detected -> fall back to the target-only MNI
    hbeta_wtm <- hbeta
  } else {
    # inverse-CV-loss weights over the informative TM estimates
    weights <- 1 / loss_tm[infosource_ind]
    weights <- weights / sum(weights)
    hbeta_wtm <- Reduce(`+`, Map(`*`, hbeta_tm[infosource_ind], weights))
  }

  # ------ POOLED MNI (Song et al., 2024) -------
  hbeta_pm <- lapply(seq_len(Q), function(q) {
    .mni_fit(rbind(X, Xs[[q]]), rbind(Y, Ys[[q]]))
  })

  # ------ EXCESS RISK CALCULATION -------
  risk <- .excess_risk(hbeta, beta, Sigma)
  risk_tm <- vapply(hbeta_tm, .excess_risk, numeric(1), beta = beta, Sigma = Sigma)
  risk_wtm <- .excess_risk(hbeta_wtm, beta, Sigma)
  risk_pm <- vapply(hbeta_pm, .excess_risk, numeric(1), beta = beta, Sigma = Sigma)

  # ------- OUTPUT RETURN -------
  # Column q = c(risk, risk_tm[q], risk_wtm, risk_pm[q]); no dimnames.
  result <- rbind(rep(risk, Q), risk_tm, rep(risk_wtm, Q), risk_pm,
                  deparse.level = 0)
  nrows <<- nrow(result)
  result
}

# ------ SIMULATION SETUP AND IMPLEMENTATION -------
Q <- 3
p_set <- seq(from = 300, to = 1000, by = 100)
n <- 25
nq <- rep(75, Q)
s <- 500
sig_err <- sqrt(1)
SSR <- c(0, 0.3, 0.6)
K <- 5
TT <- 50 # number of independent simulation runs

# Each run yields a (4 * Q) x length(p_set) matrix, one column per p.
# sapply() flattens every 4 x Q TRANSFER_MNI output column by column.
# Runs are stacked row-wise into `results` in run order.
# The run index is passed as `T` without binding a global `T`, which would
# mask the TRUE shorthand. The rows are built in a list and bound once,
# instead of being grown with rbind() inside a loop.
system.time({
  per_run <- lapply(seq_len(TT), function(run) {
    sapply(p_set, TRANSFER_MNI, Q = Q, n = n, nq = nq, s = s, SSR = SSR,
           sig_err = sig_err, K = K, T = run)
  })
  results <- do.call(rbind, per_run)
})

# ------ RESULT COLLECTION -------
# `results` stacks TT blocks of nrows * Q rows: one block per run, one column
# per p. Within a block, sapply() flattens each 4 x Q TRANSFER_MNI output
# column by column, so rows run (target, TM_q, WTM, PM_q) for q = 1, ..., Q.
run_stride <- nrows * Q

# Row indices of `results` holding metric `offset` for source q, with exactly
# one row per simulation run. Offsets: 1 = target, 2 = TM, 3 = WTM, 4 = PM.
# Target and WTM risks are repeated for every q, so q = 1 suffices for them.
# The upper bound spans all TT runs, not just the first TT * nrows rows.
metric_rows <- function(offset, q = 1) {
  seq((q - 1) * nrows + offset, TT * run_stride, by = run_stride)
}

# mean excess risks
risk_mean <- colMeans(results[metric_rows(1), , drop = FALSE])
risk_wtm_mean <- colMeans(results[metric_rows(3), , drop = FALSE])
risk_tm_mean <- matrix(0, nrow = Q, ncol = length(p_set))
risk_pm_mean <- matrix(0, nrow = Q, ncol = length(p_set))

# standard deviation of excess risks across runs
risk_sd <- apply(results[metric_rows(1), , drop = FALSE], 2, sd)
risk_wtm_sd <- apply(results[metric_rows(3), , drop = FALSE], 2, sd)
risk_tm_sd <- matrix(0, nrow = Q, ncol = length(p_set))
risk_pm_sd <- matrix(0, nrow = Q, ncol = length(p_set))

for (q in seq_len(Q)) {
  tm_runs <- results[metric_rows(2, q), , drop = FALSE]
  pm_runs <- results[metric_rows(4, q), , drop = FALSE]
  risk_tm_mean[q, ] <- colMeans(tm_runs)
  risk_tm_sd[q, ] <- apply(tm_runs, 2, sd)
  risk_pm_mean[q, ] <- colMeans(pm_runs)
  risk_pm_sd[q, ] <- apply(pm_runs, 2, sd)
}

# Append one summary statistic's columns to `df`, named
# "risk_<label>_<stat>". The column order is: baseline, TM1..TMQ, WTM,
# PM1..PMQ. `tm` and `pm` hold one row per source, one column per p.
add_risk_columns <- function(df, stat, baseline, tm, wtm, pm) {
  col_name <- function(label) paste0("risk_", label, "_", stat)
  df[[col_name("baseline")]] <- baseline
  for (j in seq_len(nrow(tm))) {
    df[[col_name(paste0("TM", j))]] <- tm[j, ]
  }
  df[[col_name("WTM")]] <- wtm
  for (j in seq_len(nrow(pm))) {
    df[[col_name(paste0("PM", j))]] <- pm[j, ]
  }
  df
}

# Plotting data: one row per p. All MEAN columns come first, then all SD
# columns.
df_plot <- data.frame(p = p_set)
df_plot <- add_risk_columns(df_plot, "MEAN",
                            risk_mean, risk_tm_mean, risk_wtm_mean, risk_pm_mean)
df_plot <- add_risk_columns(df_plot, "SD",
                            risk_sd, risk_tm_sd, risk_wtm_sd, risk_pm_sd)

saveRDS(df_plot, "fig3a.rds")
# df_plot <- readRDS("fig3a.rds")

# Open the PDF device and configure graphics parameters: serif font, bold
# axis labels, tight margins. The cex.axis value is only the default; both
# axis() calls below set their own tick-label size.
pdf("fig3a.pdf", width = 10, height = 8)
par(list(
  family = "serif",
  mar = c(3.2, 3.4, 0.2, 0.2),
  font.lab = 2,
  mgp = c(2, 0.5, 0),
  cex.axis = 2,    # axis tick size
  cex.main = 0.5,  # main title size
  cex.sub = 0.5    # subtitle size
))
#layout(matrix(1), widths = 10, heights = 8)

# Empty canvas for the excess-risk curves. The axes are suppressed and redrawn
# below, and panel.first paints the plotting region white.
y_ticks <- seq(2, 12, by = 2)
plot(df_plot$p, df_plot$risk_baseline_MEAN,
     type = "n",
     xlim = range(df_plot$p), ylim = range(y_ticks),
     xlab = "p", ylab = "Excess Risk", cex.lab = 1.8,
     xaxt = "n", yaxt = "n",
     panel.first = rect(par("usr")[1], par("usr")[3],
                        par("usr")[2], par("usr")[4],
                        col = "white", border = NA))
axis(1, at = p_set, labels = TRUE, font = 2, cex.axis = 1.5)           # x-axis ticks
axis(2, at = y_ticks, labels = TRUE, font = 2, las = 2, cex.axis = 2)  # y-axis ticks

# title(
#   main = "TITLE",
#   font.main = 2,      # Bold font (1 = plain, 2 = bold, 3 = italic, 4 = bold italic)
#   line = 0.5          # Adjust line position
# )

# Series colours, all 75% opaque, indexed as used by the drawing code:
#   1 = baseline, 2..4 = TM1..TM3, 5 = WTM, 6..8 = PM1..PM3.
# This covers Q = 3; a larger Q needs more entries.
series_rgb <- rbind(
  c(220, 20, 60),   # "#DC143C" Crimson
  c(255, 165, 0),   # "#FFA500" Orange
  c(255, 215, 0),   # "#FFD700" Gold
  c(34, 139, 34),   # "#228B22" Forest Green
  c(65, 105, 225),  # "#4169E1" Royal Blue
  c(30, 144, 255),  # "#1E90FF" Dodger Blue
  c(75, 0, 130)     # "#4B0082" Indigo
)
colors <- c(
  adjustcolor("black", alpha.f = 0.75),
  rgb(series_rgb / 255, alpha = 0.75)
)

# Marker symbols, addressed by position in the drawing code.
pch_val <- c(
  15, # square
  17, # triangle
  18, # rhombus
  20  # dot
)
line_width <- 5
lty_val <- c("solid", "dashed")
# Marker size (cex) for the dot-marked TM/WTM curves. Squares and
# triangles/rhombi look comparable at roughly 1.5 and 2.
dot_size <- 4

# Draw one risk curve: markers at each p, then the connecting line on top.
draw_series <- function(x, y, col, pch, cex, lty) {
  points(x, y, pch = pch, col = col, cex = cex)
  lines(x, y, lty = lty, col = col, lwd = line_width)
}

grid(col = "lightgray", lty = "solid")

# Target-only baseline: squares, dashed line.
draw_series(df_plot$p, df_plot$risk_baseline_MEAN,
            col = colors[1], pch = pch_val[1], cex = 2, lty = "dashed")

# TM estimates, one per source: dots, solid lines.
for (q in seq_len(Q)) {
  draw_series(df_plot$p, df_plot[[paste0("risk_TM", q, "_MEAN")]],
              col = colors[1 + q], pch = pch_val[4], cex = dot_size, lty = "solid")
}

# Weighted TM: dot, solid line.
draw_series(df_plot$p, df_plot$risk_WTM_MEAN,
            col = colors[Q + 2], pch = pch_val[4], cex = dot_size, lty = "solid")

# Pooled MNI, one per source: rhombi, dashed lines.
for (q in seq_len(Q)) {
  draw_series(df_plot$p, df_plot[[paste0("risk_PM", q, "_MEAN")]],
              col = colors[Q + 2 + q], pch = pch_val[3], cex = 3, lty = "dashed")
}

# NOTE(review): the legend below is stale. Its labels ("PM", "SGD1..Q") and
# symbols do not match the series drawn above (PM1..PMQ), so update it before
# re-enabling.
# legend(
#   "bottomleft",  # Position in the bottom-left corner
#   legend = c("Baseline", paste0("TM", 1:Q), "WTM", "PM", paste0("SGD", 1:Q)),
#   pch = c(pch_val[1], rep(pch_val[4], Q), pch_val[4], pch_val[2], rep(pch_val[3], Q)), # symbols
#   lty = c("dashed", rep("solid", Q), "solid", "dashed", rep("dashed", Q)),  # Line types
#   col = colors,  # Colors
#   cex = 1.8,  # Legend size multiplier
#   ncol = 3,  # Three columns
#   bty = "n",  # No box around the legend
#   lwd = line_width,  # Line width for legend items
#   text.font = 2,  # Bold font for text
#   x.intersp = 0.5,  # Adjust space between symbols and text
#   y.intersp = 1.2   # Adjust vertical spacing between rows
# )

dev.off()

