library(MASS)
library(mvtnorm)

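# ------ INPUT ------
# Q: number of sources
# p: covariate dimension
# s: model coefficient signal, i.e. the squared l2-norm of the target coefficient beta
# SSR: shift-to-signal ratio (length-Q vector); the source shift delta has squared
#      l2-norm SSR * s
# sig_err: square root of the common noise variance
# T: simulation seed index (design matrices and noise are drawn with seed 2025 + T)
# The sample sizes are not inputs; they are set inside the function:
#   n: target sample size (set to n0_optim)
#   nq: source sample sizes (length-Q vector; only the Q-th source receives
#       nq_optim samples, the others receive none)

# ------ OUTPUT ------
# A 6 x Q matrix whose q-th column is
#   (risk, risk_tm[q], risk_pm, risk_sgd[q], n0_optim, nq_optim), where
# risk: excess risk of the target-only MNI
# risk_tm: excess risk of each TM estimate (length-Q vector)
# risk_pm: excess risk of pooled MNI (scalar)
# risk_sgd: excess risk of each SGD estimate (length-Q vector)
# n0_optim, nq_optim: the target and source sample sizes that were used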


TRANSFER_MNI <- function(Q, p, s, SSR, sig_err, T) {
  # One simulation replicate. It draws target and source data, then fits:
  #   - the target-only MNI,
  #   - the TM estimates,
  #   - the pooled MNI,
  #   - the pretraining-finetuning SGD estimates.
  # It returns a 6 x Q matrix whose q-th column is
  # (risk, risk_tm[q], risk_pm, risk_sgd[q], n0_optim, nq_optim).
  # NOTE(review): SSR is used as a scalar when building delta and the sample
  # sizes, i.e. the code assumes Q = 1 as in this script's setup.

  # ------ SYNTHETIC DATA GENERATION ------
  # deterministic model coefficients (generated with a fixed seed), scaled so
  # that ||beta||^2 = s and ||delta||^2 = SSR * s
  set.seed(1)
  beta  <- { u <- rnorm(p); sqrt(s) * u/sqrt(sum(u^2)) }
  delta <- { u <- rnorm(p); sqrt(SSR*s) * u/sqrt(sum(u^2)) }
  betas <- beta + delta  # source coefficient

  # non-adjusted signal-to-noise ratio (SNR)
  SNR <- sum(beta^2) / (sig_err^2)

  # target and source sample sizes; only the Q-th source receives samples.
  # Both sizes are clamped at 0, so settings where the formulas go negative
  # give empty samples instead of an error.
  nq <- rep(0, Q)
  nq_optim <- pmax(floor(p - 1 - sqrt(p*(p-1)/(SNR*(1-SSR)))), 0)
  nq[Q] <- nq_optim
  n0_optim <- max(floor(p - nq_optim - sqrt(p^2/SNR + nq_optim*SSR)), 0)
  n <- n0_optim

  # isotropic covariances
  Sigma <- diag(p)
  Sigma_s <- diag(p)

  # design matrices; an empty sample is a 0 x p matrix (rmvnorm is not called)
  draw_design <- function(m, sigma) {
    if (m > 0) rmvnorm(n = m, mean = rep(0, p), sigma = sigma) else matrix(0, 0, p)
  }
  set.seed(2025+T)
  X <- draw_design(n, Sigma)
  Xs <- lapply(seq_len(Q), function(q) draw_design(nq[q], Sigma_s))

  # response vectors (noise is drawn after all designs, target first)
  Y <- X %*% beta + sig_err * matrix(rnorm(n), ncol = 1)
  Ys <- lapply(seq_len(Q), function(q) {
    Xs[[q]] %*% betas + sig_err * matrix(rnorm(nq[q]), ncol = 1)
  })

  # ------ TARGET TASK AND TM ESTIMATES ------
  # source-only MNI, target-only MNI, and the TM estimates: start from the
  # source fit and add the target MNI of its residuals
  hbeta_s <- lapply(seq_len(Q), function(q) fit_mni(Xs[[q]], Ys[[q]]))
  hbeta <- fit_mni(X, Y)
  hbeta_tm <- lapply(hbeta_s, function(b) b + fit_mni(X, Y - X %*% b))

  # ------ POOLED MNI (Song et al., 2024) -------
  # X_pool <- t(X) %*% X + Reduce("+", lapply(1:Q, function(q) t(Xs[[q]]) %*% Xs[[q]]))
  # Y_pool <- t(X) %*% Y + Reduce("+", lapply(1:Q, function(q) t(Xs[[q]]) %*% Ys[[q]]))
  # hbeta_pm <- ginv(X_pool) %*% Y_pool
  X_pool <- rbind(X, do.call(rbind, Xs))
  Y_pool <- rbind(Y, do.call(rbind, Ys))
  hbeta_pm <- fit_mni(X_pool, Y_pool)

  # ------ Pretraining-Finetuning SGD (Wu et al., 2022) ------
  # initial step sizes for pre-training and fine-tuning respectively;
  # 0.001 is adequate for identity covariance
  gamma1 <- 0.001
  gamma2 <- 0.001
  hbeta_sgd <- lapply(seq_len(Q), function(q) {
    pretrain_finetune(Xs[[q]], Ys[[q]], X, Y, gamma1, gamma2)
  })

  # ------ EXCESS RISK CALCULATION -------
  # excess risk (b - beta)' Sigma (b - beta) of an estimate b
  excess_risk <- function(b) drop(t(b - beta) %*% Sigma %*% (b - beta))
  risk <- excess_risk(hbeta)                               # target-only MNI
  risk_tm <- vapply(hbeta_tm, excess_risk, numeric(1))     # TM estimates
  risk_pm <- excess_risk(hbeta_pm)                         # pooled MNI
  risk_sgd <- vapply(hbeta_sgd, excess_risk, numeric(1))   # SGD estimates

  # ------- OUTPUT RETURN -------
  result <- matrix(0, nrow = 6, ncol = Q)
  for (q in seq_len(Q)) {
    result[, q] <- c(risk, risk_tm[q], risk_pm, risk_sgd[q], n0_optim, nq_optim)
  }
  # NOTE(review): publishes the row count as the global `nrows`; the original
  # result-collection code of this script reads it, so it is kept.
  nrows <<- nrow(result)
  result
}

# Minimum-norm interpolator (MNI) of b on A.
#   - If A has more columns than rows: A' (A A')^{-1} b.
#   - Otherwise: the least-squares solution (A'A)^{-1} A' b.
#   - If A has no rows: the zero vector, which is the minimum-norm solution of
#     an empty system.
# Returns an ncol(A) x 1 matrix.
fit_mni <- function(A, b) {
  if (nrow(A) == 0) {
    return(matrix(0, ncol(A), 1))
  }
  if (ncol(A) > nrow(A)) {
    t(A) %*% solve(tcrossprod(A), b)
  } else {
    solve(crossprod(A), crossprod(A, b))
  }
}

# Pretraining-finetuning SGD for least squares (Wu et al., 2022).
# Starting from w = 0, it runs `source_epochs` passes of single-sample SGD over
# the source data (initial step gamma1), then `target_epochs` passes over the
# target data (initial step gamma2). Either sample may be empty.
# Returns a d x 1 matrix, d = ncol(X_source).
pretrain_finetune <- function(X_source, Y_source, X_target, Y_target, gamma1, gamma2,
                              source_epochs = 1, target_epochs = 1) {
  w <- matrix(0, ncol(X_source), 1)
  w <- sgd_phase(w, X_source, Y_source, gamma1, source_epochs)
  sgd_phase(w, X_target, Y_target, gamma2, target_epochs)
}

# One SGD phase on (X, Y) starting from w, with initial step size `gamma`.
# The step is halved every ceiling(m / log2(m)) steps, counted across epochs,
# where m = nrow(X). An empty sample returns w unchanged.
sgd_phase <- function(w, X, Y, gamma, epochs) {
  m <- nrow(X)
  if (m == 0) {
    return(w)
  }
  halve_every <- ceiling(m / log2(m))  # Inf when m == 1: never halve
  eta <- gamma
  for (epoch in seq_len(epochs)) {
    for (i in seq_len(m)) {
      x <- X[i, ]
      step <- (epoch - 1) * m + i
      if (is.finite(halve_every) && step %% halve_every == 0) {
        eta <- eta / 2
      }
      # gradient of (x'w - y)^2 / 2 is x (x'w - y); computed in O(d) rather
      # than through the d x d outer product x x'
      w <- w - eta * x * (sum(x * w) - Y[i])
    }
  }
  w
}

# ------ SIMULATION SETUP AND IMPLEMENTATION -------
Q <- 1
p_set <- seq(from = 300, to = 1000, by = 100)
s <- 10
SSR <- rep(0.1, Q)
sig_err <- sqrt(1)

system.time({
  TT <- 50 # number of independent simulation runs
  # Each run yields a (6 * Q) x length(p_set) matrix (one column per p), and
  # the runs are stacked by row. The run index is passed as TRANSFER_MNI's
  # seed index `T` without using `T` (R's alias for TRUE) as a global loop
  # variable; the list is bound once instead of growing `results` in a loop.
  results <- do.call(rbind, lapply(seq_len(TT), function(run) {
    sapply(p_set, TRANSFER_MNI, Q = Q, s = s, SSR = SSR, sig_err = sig_err, T = run)
  }))
})

# ------ RESULT COLLECTION -------
# Layout of `results`:
#   - one column per value of p_set;
#   - one block of nrows * Q rows per run, TT runs in total;
#   - inside a run's block, the rows of source q are (q - 1) * nrows + 1:nrows
#     and hold, in order: risk, risk_tm[q], risk_pm, risk_sgd[q], n0_optim,
#     nq_optim.
nrows <- nrow(results) %/% (TT * Q)

# Rows of metric k for source q, one per run, as a TT x length(p_set) matrix.
# Metrics that do not depend on the source (k = 1, 3, 5, 6) are read from
# source 1, so every run is counted exactly once. drop = FALSE keeps a matrix
# even when TT or length(p_set) is 1.
metric_runs <- function(k, q = 1) {
  rows <- seq((q - 1) * nrows + k, TT * nrows * Q, by = nrows * Q)
  results[rows, , drop = FALSE]
}

# mean excess risks
risk_mean <- colMeans(metric_runs(1))
risk_tm_mean <- matrix(0, nrow = Q, ncol = length(p_set))
for (q in seq_len(Q)) {
  risk_tm_mean[q, ] <- colMeans(metric_runs(2, q))
}
risk_pm_mean <- colMeans(metric_runs(3))
risk_sgd_mean <- matrix(0, nrow = Q, ncol = length(p_set))
for (q in seq_len(Q)) {
  risk_sgd_mean[q, ] <- colMeans(metric_runs(4, q))
}
n0_optim <- colMeans(metric_runs(5))
nq_optim <- colMeans(metric_runs(6))


# standard deviation of excess risks
risk_sd <- apply(metric_runs(1), 2, sd)
risk_tm_sd <- matrix(0, nrow = Q, ncol = length(p_set))
for (q in seq_len(Q)) {
  risk_tm_sd[q, ] <- apply(metric_runs(2, q), 2, sd)
}
risk_pm_sd <- apply(metric_runs(3), 2, sd)
risk_sgd_sd <- matrix(0, nrow = Q, ncol = length(p_set))
for (q in seq_len(Q)) {
  risk_sgd_sd[q, ] <- apply(metric_runs(4, q), 2, sd)
}


# result dataframe for risk plotting, one row per p. Column order:
#   - the mean columns (baseline, TM per source, PM, SGD per source);
#   - the matching SD columns;
#   - the sample sizes n0_optim and nq_optim.
df_plot <- local({
  # one named column per source, taken from the rows of a Q x length(p_set)
  # matrix and named <prefix><source index><suffix>
  by_source <- function(prefix, suffix, stats) {
    cols <- lapply(seq_len(Q), function(j) stats[j, ])
    names(cols) <- paste0(prefix, seq_len(Q), suffix)
    cols
  }
  columns <- c(
    list(p = p_set, risk_baseline_MEAN = risk_mean),
    by_source("risk_TM", "_MEAN", risk_tm_mean),
    list(risk_PM_MEAN = risk_pm_mean),
    by_source("risk_SGD", "_MEAN", risk_sgd_mean),
    list(risk_baseline_SD = risk_sd),
    by_source("risk_TM", "_SD", risk_tm_sd),
    list(risk_PM_SD = risk_pm_sd),
    by_source("risk_SGD", "_SD", risk_sgd_sd),
    list(n0_optim = n0_optim, nq_optim = nq_optim)
  )
  as.data.frame(columns)
})

saveRDS(df_plot, "fig4a.rds")
# df_plot <- readRDS("fig4a.rds")

# ------ PLOTTING ------
# Draws the mean excess risk of every method against p into fig4a.pdf
# (a single 10 x 8 inch panel).
pdf("fig4a.pdf", width = 10, height = 8)
par(family = "serif")
layout(matrix(1), widths = 10, heights = 8)
par(mar = c(3.2, 3.4, 0.2, 0.2), font.lab = 2, mgp = c(2, 0.5, 0),
    cex.axis = 2,   # axis tick size
    cex.main = 0.5,   # main title size
    cex.sub = 0.5)    # subtitle size

# Plot for excess risks: an empty frame with a white background; both axes
# are drawn manually below.
# NOTE(review): the y-range is fixed to [4, 12], so means outside it are
# clipped; confirm the range still fits if the simulation settings change.
plot(df_plot$p, df_plot$risk_baseline_MEAN, type = "n", ylim = c(4, 12), 
     cex.lab=1.8,
     ylab = "Excess Risk", xlab = "p", xaxt='n', yaxt='n', 
     xlim = range(df_plot$p), 
     panel.first = rect(par("usr")[1], par("usr")[3], par("usr")[2], par("usr")[4], col = "white", border = NA))
axis(1, at = p_set, labels = TRUE, font = 2, cex.axis=1.5)  # x-axis ticks
axis(2, at = seq(4, 12, by = 2), labels = TRUE, font = 2, las = 2, cex.axis=2) # y-axis ticks

# title(
#   main = "TITLE",
#   font.main = 2,      # Bold font (1 = plain, 2 = bold, 3 = italic, 4 = bold italic)
#   line = 0.5          # Adjust line position
# )

# Line/marker colours (75% opaque). Indexing used by the drawing code below:
#   1         = baseline (target-only MNI)
#   1+q       = TM estimate for source q
#   1+Q+1     = pooled MNI
#   1+Q+1+q   = SGD estimate for source q
# Only four colours are active, which covers Q = 1.
colors <- c(
  adjustcolor("black", alpha.f = 0.75),
  #rgb(54/255, 69/255, 79/255, alpha = 0.75),    # "#36454F" Charcoal
  
  rgb(220/255, 20/255, 60/255, alpha = 0.75),   # "#DC143C" Crimson
  # rgb(255/255, 165/255, 0/255, alpha = 0.75),    # "#FFA500" Orange
  # rgb(255/255, 215/255, 0/255, alpha = 0.75),    # "#FFD700" Gold
  
  # rgb(34/255, 139/255, 34/255, alpha = 0.75),    # "#228B22" Forest Green
  
  rgb(65/255, 105/255, 225/255, alpha = 0.75),  # "#4169E1" Royal Blue
  # rgb(0/255, 0/255, 128/255, alpha = 0.75),      # "#000080" Navy
  
  rgb(255/255, 105/255, 180/255, alpha = 0.75)  # "#FF69B4" Hot Pink
  # rgb(128/255, 0/255, 128/255, alpha = 0.75)    # "#800080" Purple
  # adjustcolor("lightpink1", alpha.f = 0.75)
  
  # rgb(65/255, 105/255, 225/255, alpha = 0.75),  # "#4169E1" Royal Blue
  # 
  # rgb(255/255, 215/255, 0/255, alpha = 0.75),    # "#FFD700" Gold
  # rgb(255/255, 99/255, 71/255, alpha = 0.75),     # "#FF6347" Tomato
  # rgb(138/255, 43/255, 226/255, alpha = 0.75),  # "#8A2BE2" Blue Violet
  # rgb(0/255, 128/255, 128/255, alpha = 0.75),   # "#008080" Teal
  # rgb(30/255, 144/255, 255/255, alpha = 0.75),  # "#1E90FF" Dodger Blue
  # rgb(75/255, 0/255, 130/255, alpha = 0.75),     # "#4B0082" Indigo
)

# Define plot customization parameters
# Marker usage below: pch_val[1] baseline, [2] pooled MNI, [3] SGD, [4] TM
pch_val <- c(
  15, # square
  17, # triangle
  18, # rhombus
  20  # dot
)
line_width <- 5
lty_val <- c("solid", "dashed")
# marker size for the TM dots (pch 20); the trailing note lists alternative
# sizes for the other symbols
dot_size <- 4 # 1.5 for square, 2 for triangle and rhombus


# background grid, then markers and a connecting line for each method
grid(col = "lightgray", lty = "solid")
# baseline: target-only MNI
points(df_plot$p, df_plot$risk_baseline_MEAN, pch = pch_val[1], col = colors[1], cex = 2)
lines(df_plot$p, df_plot$risk_baseline_MEAN, lty = "dashed", col = colors[1], lwd = line_width)
# TM estimates, one curve per source
for (q in 1:Q) {
  points(df_plot$p, df_plot[[paste0("risk_TM", q, "_MEAN")]], pch = pch_val[4], col = colors[1+q], cex = dot_size)
  lines(df_plot$p, df_plot[[paste0("risk_TM", q, "_MEAN")]], lty = "solid", col = colors[1+q], lwd = line_width)
}
# pooled MNI
points(df_plot$p, df_plot$risk_PM_MEAN, pch = pch_val[2], col = colors[1+Q+1], cex = 3)
lines(df_plot$p, df_plot$risk_PM_MEAN, lty = "dashed", col = colors[1+Q+1], lwd = line_width)
# pretraining-finetuning SGD, one curve per source
for (q in 1:Q) {
  points(df_plot$p, df_plot[[paste0("risk_SGD", q, "_MEAN")]], pch = pch_val[3], col = colors[1+Q+1+q], cex = 3)
  lines(df_plot$p, df_plot[[paste0("risk_SGD", q, "_MEAN")]], lty = "dashed", col = colors[1+Q+1+q], lwd = line_width)
}

# Legend (currently disabled)
# legend(
#   "topright",  # Position in the top right corner
#   legend = c("Baseline", paste0("TM", 1:Q), "PM", paste0("SGD", 1:Q)),
#   pch = c(pch_val[1], rep(pch_val[4], Q), pch_val[2], rep(pch_val[3], Q)), # symbols
#   lty = c("dashed", rep("solid", Q), "dashed", rep("dashed", Q)),  # Line types
#   col = colors,  # Colors
#   cex = 1.8,  # Legend size multiplier
#   ncol = 3,  # Three columns
#   bty = "n",  # No box around the legend
#   lwd = line_width,  # Line width for legend items
#   text.font = 2,  # Bold font for text
#   x.intersp = 0.5,  # Adjust space between symbols and text
#   y.intersp = 1.2   # Adjust vertical spacing between rows
# )

# close the PDF device, which writes fig4a.pdf
dev.off()

