################
### Packages ###
################

library(glmnet)
library(progress)
library(IsingSampler)
library(IsingFit)
library(qgraph)
library(ggplot2)
library(pROC)
library(reshape2)
library(dplyr)
library(plotly)
library(ncvreg)
library(Matrix)
library(patchwork)
library(igraph)


#################
### Functions ###
#################

#' Trans-LogLasso with cross-validated detection of informative auxiliaries
#'
#' The target sample `x_0` is split into `fold` random folds. For every fold,
#' `naive_loglasso()` is fitted on the target training rows alone (using
#' `lambda_delta`) and on the training rows stacked with each auxiliary data
#' set (using `lambda_j`); both fits are scored with
#' `pseudolikelihood_loss()` on the held-out target rows. An auxiliary data
#' set is kept when its mean held-out loss minus the mean target-only loss is
#' below `threshold_factor * sigma_hat`, where `sigma_hat` is the sample
#' standard deviation of the target-only fold losses. The kept data sets are
#' then handed to `oracle_trans_loglasso()`.
#'
#' @param x_A_list List of auxiliary data matrices (may be empty).
#' @param x_0 Target data matrix (rows = observations, columns = nodes).
#' @param lambda_j Penalty for the combined-data fits; `NULL` lets the
#'   callee choose it.
#' @param lambda_delta Penalty for the target-only fits; `NULL` lets the
#'   callee choose it.
#' @param lambda_scad Passed through to `oracle_trans_loglasso()`.
#' @param fold Number of cross-validation folds (2 <= fold <= nrow(x_0)).
#' @param threshold_factor Multiplier applied to `sigma_hat` to obtain the
#'   selection threshold. The default 0.5 reproduces `sigma_hat / 2`.
#' @return A list with `beta_hat`, `informative_set`, `loss_diffs`,
#'   `sigma_hat` and `adaptive_threshold`.
trans_loglasso <- function(x_A_list, x_0, lambda_j = NULL, lambda_delta = NULL,
                           lambda_scad = NULL, fold = 2,
                           threshold_factor = 0.5) {
  n_0 <- nrow(x_0)
  S <- length(x_A_list)

  # With fewer than two folds sigma_hat is 0/0, and with more folds than rows
  # some test folds are empty; both cases would silently produce NaN losses.
  if (fold < 2 || fold > n_0) {
    stop("trans_loglasso: 'fold' must be between 2 and nrow(x_0).",
         call. = FALSE)
  }

  folds <- sample(rep(seq_len(fold), length.out = n_0))

  primary_losses <- numeric(fold)
  combined_losses <- matrix(0, nrow = S, ncol = fold)

  cat("=== Step 2: Source Detection via Pseudolikelihood Loss Comparison ===\n")

  for (r in seq_len(fold)) {
    is_test <- folds == r

    # drop = FALSE keeps single-row folds as matrices.
    x_train <- x_0[!is_test, , drop = FALSE]
    x_test <- x_0[is_test, , drop = FALSE]

    theta_primary <- naive_loglasso(x_train, lambda_delta)
    primary_losses[r] <- pseudolikelihood_loss(x_test, theta_primary)

    # seq_len() makes this a no-op when no auxiliary data are supplied
    # (1:S would iterate over c(1, 0) and fail on x_A_list[[1]]).
    for (k in seq_len(S)) {
      x_combined <- rbind(x_train, x_A_list[[k]])

      theta_combined <- naive_loglasso(x_combined, lambda_j)
      combined_losses[k, r] <- pseudolikelihood_loss(x_test, theta_combined)
    }
  }

  L0_bar <- mean(primary_losses)
  sigma_hat <- sqrt(sum((primary_losses - L0_bar)^2) / (fold - 1))

  loss_diffs <- rowMeans(combined_losses) - L0_bar
  for (k in seq_len(S)) {
    cat(sprintf("Auxiliary %d: Avg Loss Diff = %.6f\n", k, loss_diffs[k]))
  }

  adaptive_threshold <- sigma_hat * threshold_factor
  informative_indices <- which(loss_diffs < adaptive_threshold)

  cat(sprintf("Selected %d informative auxiliary datasets (threshold = %.6f): %s\n",
              length(informative_indices), adaptive_threshold, paste(informative_indices, collapse = ", ")))

  x_A_selected <- x_A_list[informative_indices]

  cat("=== Step 4: Oracle Trans-LogLasso on Selected Auxiliaries ===\n")
  result <- oracle_trans_loglasso(x_A_selected, x_0, lambda_j, lambda_delta, lambda_scad)

  list(
    beta_hat = result$beta_hat,
    informative_set = informative_indices,
    loss_diffs = loss_diffs,
    sigma_hat = sigma_hat,
    adaptive_threshold = adaptive_threshold
  )
}




#' Trans-LogLasso with a looser source-detection threshold
#'
#' Same procedure as the cross-validated source detection used elsewhere in
#' this file: the target sample `x_0` is split into `fold` random folds;
#' for every fold `naive_loglasso()` is fitted on the target training rows
#' alone (using `lambda_delta`) and on the training rows stacked with each
#' auxiliary data set (using `lambda_j`), and both fits are scored with
#' `pseudolikelihood_loss()` on the held-out target rows. An auxiliary data
#' set is kept when its mean held-out loss minus the mean target-only loss is
#' below `threshold_factor * sigma_hat`; here the default multiplier is 2.
#' The kept data sets are handed to `oracle_trans_loglasso()`.
#'
#' @param x_A_list List of auxiliary data matrices (may be empty).
#' @param x_0 Target data matrix (rows = observations, columns = nodes).
#' @param lambda_j Penalty for the combined-data fits; `NULL` lets the
#'   callee choose it.
#' @param lambda_delta Penalty for the target-only fits; `NULL` lets the
#'   callee choose it.
#' @param lambda_scad Passed through to `oracle_trans_loglasso()`.
#' @param fold Number of cross-validation folds (2 <= fold <= nrow(x_0)).
#' @param threshold_factor Multiplier applied to `sigma_hat` to obtain the
#'   selection threshold. The default 2 reproduces `sigma_hat * 2`.
#' @return A list with `beta_hat`, `informative_set`, `loss_diffs`,
#'   `sigma_hat` and `adaptive_threshold`.
trans_loglasso_2 <- function(x_A_list, x_0, lambda_j = NULL,
                             lambda_delta = NULL, lambda_scad = NULL,
                             fold = 2, threshold_factor = 2) {
  n_0 <- nrow(x_0)
  S <- length(x_A_list)

  # With fewer than two folds sigma_hat is 0/0, and with more folds than rows
  # some test folds are empty; both cases would silently produce NaN losses.
  if (fold < 2 || fold > n_0) {
    stop("trans_loglasso_2: 'fold' must be between 2 and nrow(x_0).",
         call. = FALSE)
  }

  folds <- sample(rep(seq_len(fold), length.out = n_0))
  primary_losses <- numeric(fold)
  combined_losses <- matrix(0, nrow = S, ncol = fold)

  cat("=== Step 2: Source Detection via Pseudolikelihood Loss Comparison ===\n")

  for (r in seq_len(fold)) {
    is_test <- folds == r

    # drop = FALSE keeps single-row folds as matrices.
    x_train <- x_0[!is_test, , drop = FALSE]
    x_test <- x_0[is_test, , drop = FALSE]

    theta_primary <- naive_loglasso(x_train, lambda_delta)
    primary_losses[r] <- pseudolikelihood_loss(x_test, theta_primary)

    # seq_len() makes this a no-op when no auxiliary data are supplied
    # (1:S would iterate over c(1, 0) and fail on x_A_list[[1]]).
    for (k in seq_len(S)) {
      x_combined <- rbind(x_train, x_A_list[[k]])

      theta_combined <- naive_loglasso(x_combined, lambda_j)
      combined_losses[k, r] <- pseudolikelihood_loss(x_test, theta_combined)
    }
  }

  L0_bar <- mean(primary_losses)
  sigma_hat <- sqrt(sum((primary_losses - L0_bar)^2) / (fold - 1))

  loss_diffs <- rowMeans(combined_losses) - L0_bar
  for (k in seq_len(S)) {
    cat(sprintf("Auxiliary %d: Avg Loss Diff = %.6f\n", k, loss_diffs[k]))
  }

  adaptive_threshold <- sigma_hat * threshold_factor
  informative_indices <- which(loss_diffs < adaptive_threshold)

  cat(sprintf("Selected %d informative auxiliary datasets (threshold = %.6f): %s\n",
              length(informative_indices), adaptive_threshold, paste(informative_indices, collapse = ", ")))

  x_A_selected <- x_A_list[informative_indices]

  cat("=== Step 4: Oracle Trans-LogLasso on Selected Auxiliaries ===\n")
  result <- oracle_trans_loglasso(x_A_selected, x_0, lambda_j, lambda_delta, lambda_scad)

  list(
    beta_hat = result$beta_hat,
    informative_set = informative_indices,
    loss_diffs = loss_diffs,
    sigma_hat = sigma_hat,
    adaptive_threshold = adaptive_threshold
  )
}




#' Average negative log-pseudolikelihood of a pairwise logistic model
#'
#' For every column j of `x`, the linear predictor is built from the other
#' columns weighted by row j of `theta` (the diagonal of `theta` is ignored
#' and no intercept is used). The logistic log-likelihoods of all columns are
#' summed and the negated total is divided by the number of rows.
#'
#' @param x Numeric n x p matrix of responses (the formula treats entries as
#'   0/1 outcomes).
#' @param theta Numeric p x p coefficient matrix.
#' @return A single number: -(sum of node-wise log-likelihoods) / n.
pseudolikelihood_loss <- function(x, theta) {
  n <- nrow(x)

  # Zeroing the diagonal removes each node's own column from its predictor,
  # which is what dropping column j / entry j does in a node-by-node loop.
  theta_off <- theta
  diag(theta_off) <- 0

  # eta[i, j] = sum_{k != j} x[i, k] * theta[j, k]
  eta <- x %*% t(theta_off)

  # log(1 + exp(eta)) overflows to Inf once eta exceeds ~709; this form of the
  # softplus is algebraically identical and stays finite for any eta.
  softplus <- pmax(eta, 0) + log1p(exp(-abs(eta)))

  -sum(x * eta - softplus) / n
}



#' SCAD penalty value
#'
#' Piecewise penalty evaluated element-wise on `beta`:
#' * |beta| <= lambda:            lambda * |beta|
#' * lambda < |beta| <= a*lambda: -(|beta|^2 - 2*a*lambda*|beta| + lambda^2) /
#'                                (2 * (a - 1))
#' * |beta| > a*lambda:           (a + 1) * lambda^2 / 2
#'
#' @param beta Numeric scalar or vector of coefficients.
#' @param lambda Penalty level (scalar).
#' @param a Shape parameter of the penalty (default 3.7).
#' @return Numeric vector of penalties with the same length as `beta`.
scad_penalty_fun <- function(beta, lambda, a = 3.7) {
  abs_beta <- abs(beta)

  # Start from the linear piece and overwrite the entries that fall into the
  # quadratic and constant regions. which() is used so that NA entries simply
  # stay NA instead of breaking the subscripted assignment.
  penalty <- lambda * abs_beta

  mid <- which(abs_beta > lambda & abs_beta <= a * lambda)
  penalty[mid] <- -(abs_beta[mid]^2 - 2 * a * lambda * abs_beta[mid] + lambda^2) /
    (2 * (a - 1))

  flat <- which(abs_beta > a * lambda)
  penalty[flat] <- (a + 1) * lambda^2 / 2

  penalty
}



#' Objective minimised in the bias-correction step
#'
#' Evaluates, for a candidate correction `delta`, the sum of
#' (1) the negative logistic log-likelihood of `y_j` given `X_j` at the
#'     coefficients `theta_A_j + delta`,
#' (2) an L1 penalty `lambda_delta * sum(|delta|)`, and
#' (3) the SCAD penalty (`scad_penalty_fun()`) of every corrected coefficient
#'     at level `lambda_scad`.
#'
#' @param delta Numeric vector: candidate correction.
#' @param theta_A_j Numeric vector: initial coefficients, same length as
#'   `delta`.
#' @param X_j Predictor matrix with one column per coefficient.
#' @param y_j Response vector (the likelihood treats it as 0/1).
#' @param lambda_delta L1 penalty level on `delta`.
#' @param lambda_scad SCAD penalty level on `theta_A_j + delta`.
#' @return The scalar objective value.
objective_fun_step2 <- function(delta, theta_A_j, X_j, y_j, lambda_delta, lambda_scad) {
  beta_final <- theta_A_j + delta

  # The linear predictor is clipped to [-10, 10] so exp() cannot overflow.
  eta <- X_j %*% beta_final
  eta <- pmin(pmax(eta, -10), 10)
  loss <- -sum(y_j * eta - log1p(exp(eta)))

  penalty_lasso <- lambda_delta * sum(abs(delta))

  # vapply() always yields a numeric vector (numeric(0) for empty input),
  # whereas sapply() returns list() there and makes sum() fail.
  penalty_scad <- sum(vapply(beta_final, scad_penalty_fun, numeric(1),
                             lambda = lambda_scad))

  loss + penalty_lasso + penalty_scad
}



#' Generate a p x p interaction matrix for a chosen graph structure
#'
#' Supported `graph_type` values:
#' * "random": symmetric Gaussian matrix (sum of a standard normal matrix and
#'   its transpose, divided by 4) with entries of absolute value below `thres`
#'   set to zero and a zero diagonal.
#' * "chain": edges between consecutive nodes i and i + 1.
#' * "star": edges between node 1 and every other node.
#' * "block_star": consecutive blocks of `block_size` nodes; the first node of
#'   each block is connected to all other nodes of that block.
#' * "sparse_block_star": as "block_star", but each centre is connected to
#'   only `num_spokes` randomly chosen nodes of its block.
#' * "ar1": inverse of the matrix with entries 0.5^|i - j|.
#'   NOTE(review): unlike the other types this result has a non-zero
#'   diagonal - confirm that callers expect that.
#'
#' Edge weights of the chain/star variants are drawn from Uniform(0.5, 1.5).
#'
#' @param p Number of nodes.
#' @param thres Magnitude threshold (used by "random" only).
#' @param graph_type One of the structures listed above.
#' @param block_size Block length for the block variants; must divide `p`.
#' @param num_spokes Spokes per block for "sparse_block_star"; must be smaller
#'   than `block_size`.
#' @return A p x p numeric matrix.
generate_theta_true <- function(p, thres = 0.7, graph_type = "random", block_size = 10, num_spokes = 5) {
  theta <- matrix(0, p, p)

  if (graph_type == "random") {
    theta_raw <- matrix(rnorm(p * p), p, p)
    theta <- (theta_raw + t(theta_raw)) / 4
    theta[abs(theta) < thres] <- 0
    diag(theta) <- 0

  } else if (graph_type == "chain") {
    # seq_len() yields no iterations for p <= 1 (1:(p - 1) would count down).
    for (i in seq_len(max(p - 1, 0))) {
      w <- runif(1, min = 0.5, max = 1.5)
      theta[i, i + 1] <- w
      theta[i + 1, i] <- w
    }

  } else if (graph_type == "star") {
    # Dropping the first index leaves nothing to do when p == 1.
    for (i in seq_len(p)[-1]) {
      w <- runif(1, min = 0.5, max = 1.5)
      theta[1, i] <- w
      theta[i, 1] <- w
    }

  } else if (graph_type == "block_star") {
    if (p %% block_size != 0) {
      stop("p must be divisible by block_size for block_star structure.")
    }

    num_blocks <- p / block_size
    for (b in seq_len(num_blocks) - 1) {
      center <- b * block_size + 1

      # Empty when block_size == 1, so singleton blocks get no edges.
      for (i in center + seq_len(block_size - 1)) {
        w <- runif(1, min = 0.5, max = 1.5)
        theta[center, i] <- w
        theta[i, center] <- w
      }
    }

  } else if (graph_type == "sparse_block_star") {
    if (p %% block_size != 0) {
      stop("p must be divisible by block_size for sparse_block_star structure.")
    }
    if (num_spokes >= block_size) {
      stop("num_spokes must be less than block_size.")
    }

    num_blocks <- p / block_size
    for (b in seq_len(num_blocks) - 1) {
      center <- b * block_size + 1
      spoke_candidates <- center + seq_len(block_size - 1)

      # sample(x, ...) treats a single number x as 1:x, which would pick nodes
      # outside the block when only one candidate exists. Sampling positions
      # avoids that and draws the same values whenever there are two or more
      # candidates.
      picked <- sample.int(length(spoke_candidates), size = num_spokes)
      selected_spokes <- spoke_candidates[picked]

      for (i in selected_spokes) {
        w <- runif(1, min = 0.5, max = 1.5)
        theta[center, i] <- w
        theta[i, center] <- w
      }
    }

  } else if (graph_type == "ar1") {
    rho <- 0.5
    Sigma <- outer(1:p, 1:p, function(i, j) rho^abs(i - j))
    theta <- solve(Sigma)

  } else {
    stop("Unknown graph_type. Choose from 'random', 'chain', 'star', 'block_star', 'sparse_block_star', or 'ar1'.")
  }

  theta
}


#' Generate a sparse symmetric p x p interaction matrix
#'
#' Each of the p^2 cells is switched on with the probabilities in `prob`
#' (order: off, on) and multiplied by an independent standard normal weight.
#' The matrix is then made symmetric by taking the element-wise maximum with
#' its transpose, and the diagonal is set to zero.
#' NOTE(review): the element-wise maximum keeps the larger of each (i, j) /
#' (j, i) pair, so a negative weight only survives when its mirror entry is
#' also negative - confirm this is the intended symmetrization.
#'
#' @param p Number of nodes.
#' @param prob Sampling probabilities for the values 0 and 1 of the edge mask.
#' @return A symmetric p x p numeric matrix with zero diagonal.
generate_theta_true_2 <- function(p, prob = c(0.95, 0.05)) {
  n_cells <- p^2

  # The mask is drawn before the weights; keeping this order keeps results
  # reproducible under a fixed seed.
  edge_mask <- matrix(sample(0:1, n_cells, replace = TRUE, prob = prob),
                      nrow = p, ncol = p)
  weights <- matrix(rnorm(n_cells), nrow = p, ncol = p)

  weighted <- edge_mask * weights
  symmetric <- pmax(weighted, t(weighted))
  diag(symmetric) <- 0

  symmetric
}



#' Oracle Trans-LogLasso: pooled estimate plus target-only bias correction
#'
#' With an empty `x_A_list` the function falls back to `naive_loglasso()` on
#' the target data. Otherwise it runs three steps:
#' 1. Node-wise L1-penalised logistic regressions (glmnet, no intercept) on
#'    the target rows stacked with all auxiliary rows give `theta_hat_A`.
#' 2. For every node, a correction `delta` is obtained on the target data only
#'    by minimising `objective_fun_step2()` with `optim()` (Nelder-Mead,
#'    started at zero, at most 200 iterations).
#' 3. `theta_hat_A + delta_hat_A` is symmetrized with the requested rule.
#'
#' @param x_A_list List of auxiliary data matrices (may be empty).
#' @param x_0 Target data matrix (rows = observations, columns = nodes).
#' @param lambda_j Step-1 penalty; `NULL` selects it per node by `cv.glmnet`.
#' @param lambda_delta L1 penalty on the correction; `NULL` derives it from a
#'   `cv.ncvreg` fit (the selected lambda).
#' @param lambda_scad SCAD penalty level; `NULL` derives it from the same fit
#'   (three times the selected lambda).
#' @param symmetrize_rule "AND" (keep an edge only if both directions exceed
#'   1e-6 in magnitude), "OR" (either direction); any other value uses the
#'   plain average.
#' @return A list with `beta_hat` (symmetrized), `theta_hat_A`, `delta_hat_A`
#'   and `beta_hat_asym`.
oracle_trans_loglasso <- function(x_A_list, x_0, lambda_j = NULL, lambda_delta = NULL, lambda_scad = NULL, symmetrize_rule = "AND") {

  p <- ncol(x_0)
  n_0 <- nrow(x_0)

  if (length(x_A_list) == 0) {
    cat("No auxiliary data provided. Calling naive_loglasso directly.\n")

    theta_hat <- naive_loglasso(x_0, lambda_delta, symmetrize_rule)
    return(list(
      beta_hat = theta_hat,
      theta_hat_A = theta_hat,
      delta_hat_A = matrix(0, nrow = p, ncol = p),
      beta_hat_asym = theta_hat
    ))
  }

  n_A <- sum(vapply(x_A_list, nrow, integer(1)))
  x_A <- do.call(rbind, x_A_list)
  x_combined <- rbind(x_0, x_A)

  # Candidate penalty grids do not depend on the node, so build them once.
  grid_multipliers <- seq(1.5, 0.5, length.out = 11)
  lambda_grid_step1 <- grid_multipliers * sqrt(log(p) / (n_0 + n_A))
  lambda_grid_step2 <- grid_multipliers * sqrt(log(p) / n_0)

  theta_hat_A <- matrix(0, nrow = p, ncol = p)
  pb1 <- progress_bar$new(total = p, format = "  Step 1 [:bar] :percent eta: :eta")
  cat("=== Step 1: Initial Estimation (Asymmetric) ===\n")

  for (j in seq_len(p)) {
    pb1$tick()
    y_j_combined <- x_combined[, j]
    x_minus_j_combined <- x_combined[, -j]

    # Same degenerate-response rule as Step 2 and naive_loglasso(): a node
    # that is (nearly) constant cannot be fitted, so its column stays zero
    # instead of aborting the whole run.
    tab_combined <- table(y_j_combined)
    if (length(tab_combined) < 2 || any(tab_combined < 3)) next

    if (is.null(lambda_j)) {
      cv_fit <- cv.glmnet(x_minus_j_combined, y_j_combined, family = "binomial", alpha = 1,
                          lambda = lambda_grid_step1,
                          intercept = FALSE)
      current_lambda_j <- cv_fit$lambda.min
    } else {
      current_lambda_j <- lambda_j
    }

    fit_combined <- glmnet(x_minus_j_combined, y_j_combined,
                           family = "binomial", alpha = 1, lambda = current_lambda_j,
                           intercept = FALSE)
    # The first coefficient row is the intercept slot; drop it.
    theta_hat_A[-j, j] <- as.numeric(coef(fit_combined, s = current_lambda_j)[-1])
  }
  cat("Step 1 completed.\n\n")

  delta_hat_A <- matrix(0, nrow = p, ncol = p)
  pb2 <- progress_bar$new(total = p, format = "  Step 2 [:bar] :percent eta: :eta")
  cat("=== Step 2: Bias Correction (Asymmetric) ===\n")

  for (j in seq_len(p)) {
    pb2$tick()
    y_j_0 <- x_0[, j]
    x_minus_j_0 <- x_0[, -j]

    if (length(unique(y_j_0)) < 2 || any(table(y_j_0) < 3)) next

    if (is.null(lambda_delta) || is.null(lambda_scad)) {
      offset_values <- x_minus_j_0 %*% theta_hat_A[-j, j]
      offset_values <- pmax(pmin(offset_values, 10), -10)

      # NOTE(review): `offset` and `intercept` are forwarded to cv.ncvreg as
      # written in the original code; confirm against the ncvreg documentation
      # that they are honoured rather than silently ignored.
      fit_scad_cv <- tryCatch({
        cv.ncvreg(X = x_minus_j_0, y = y_j_0,
                  family = "binomial", penalty = "SCAD", offset = offset_values,
                  lambda = lambda_grid_step2,
                  intercept = FALSE)
      }, error = function(e) NULL)

      # A failed CV fit leaves this node without correction (delta = 0).
      if (is.null(fit_scad_cv)) next

      best_lambda <- fit_scad_cv$lambda.min
      # Only fill in the penalty that was not supplied; a user-provided value
      # is kept even when the other one has to be tuned.
      tuned_lambda_delta <- if (is.null(lambda_delta)) best_lambda else lambda_delta
      tuned_lambda_scad <- if (is.null(lambda_scad)) 3 * best_lambda else lambda_scad

    } else {
      tuned_lambda_delta <- lambda_delta
      tuned_lambda_scad  <- lambda_scad
    }

    # NOTE(review): Nelder-Mead over p - 1 parameters with maxit = 200 may
    # stop before converging for larger p; optim_result$convergence is not
    # inspected here.
    optim_result <- optim(
      par = rep(0, p - 1),
      fn = objective_fun_step2,
      theta_A_j = theta_hat_A[-j, j],
      X_j = x_minus_j_0,
      y_j = y_j_0,
      lambda_delta = tuned_lambda_delta,
      lambda_scad = tuned_lambda_scad,
      method = "Nelder-Mead",
      control = list(maxit = 200)
    )
    delta_hat_A[-j, j] <- optim_result$par
  }
  cat("Step 2 completed.\n\n")

  cat("=== Step 3: Final Estimation and Symmetrization ===\n")

  beta_hat_asym <- theta_hat_A + delta_hat_A

  rule <- toupper(symmetrize_rule)
  averaged <- (beta_hat_asym + t(beta_hat_asym)) / 2
  nonzero <- abs(beta_hat_asym) > 1e-6

  if (rule %in% c("AND", "OR")) {
    keep <- if (rule == "AND") nonzero & t(nonzero) else nonzero | t(nonzero)
    diag(keep) <- FALSE
    kept_cells <- which(keep)

    beta_hat <- matrix(0, nrow = p, ncol = p)
    beta_hat[kept_cells] <- averaged[kept_cells]
  } else {
    # Same notice as naive_loglasso() so an unrecognised rule is not used
    # silently.
    cat("Warning: Unknown rule. Using simple averaging for symmetrization.\n")
    beta_hat <- averaged
  }

  cat("Symmetrization with '", toupper(symmetrize_rule), "' rule completed.\n\n", sep="")

  list(
    beta_hat = beta_hat,
    theta_hat_A = theta_hat_A,
    delta_hat_A = delta_hat_A,
    beta_hat_asym = beta_hat_asym
  )
}


#' Node-wise L1-penalised logistic regression on a single data set
#'
#' Every column of `x_0` is regressed on all other columns with glmnet
#' (binomial family, lasso penalty, no intercept). Nodes whose response has
#' fewer than two observed values, or fewer than three observations of some
#' value, are skipped and keep zero coefficients; so do nodes whose fit
#' fails (a warning is emitted in that case). The asymmetric coefficient
#' matrix is then symmetrized.
#'
#' @param x_0 Data matrix (rows = observations, columns = nodes).
#' @param lambda_delta Penalty level; `NULL` selects it per node with
#'   `cv.glmnet` over an 11-point grid scaled by sqrt(log(p) / n).
#' @param symmetrize_rule "AND" (keep an edge only if both directions exceed
#'   1e-6 in magnitude), "OR" (either direction); any other value uses the
#'   plain average.
#' @return A symmetric p x p numeric matrix of estimated interactions.
naive_loglasso <- function(x_0, lambda_delta = NULL, symmetrize_rule = "AND") {
  p <- ncol(x_0)
  n_0 <- nrow(x_0)

  theta_hat_primary <- matrix(0, nrow = p, ncol = p)

  # The candidate grid is the same for every node, so build it once.
  lambda_grid <- seq(1.5, 0.5, length.out = 11) * sqrt(log(p) / n_0)

  pb <- progress_bar$new(
    total = p,
    format = "  Naive Lasso [:bar] :percent eta: :eta"
  )
  cat("=== Naive-LogLasso Estimation (Asymmetric) ===\n")

  for (j in seq_len(p)) {
    pb$tick()

    y_j <- x_0[, j]
    x_minus_j <- x_0[, -j]

    tab <- table(y_j)
    if (length(tab) < 2 || any(tab < 3)) {
      next
    }

    if (is.null(lambda_delta)) {
      cv_fit <- tryCatch({
        cv.glmnet(x_minus_j, y_j, family = "binomial", alpha = 1,
                  lambda = lambda_grid,
                  intercept = FALSE)
      }, error = function(e) {
        # Keep the best-effort behaviour but report why the node was skipped.
        warning("naive_loglasso: cv.glmnet failed for node ", j, ": ",
                conditionMessage(e), call. = FALSE)
        NULL
      })

      if (is.null(cv_fit)) next
      current_lambda <- cv_fit$lambda.min
    } else {
      current_lambda <- lambda_delta
    }

    fit <- tryCatch({
      glmnet(x_minus_j, y_j, family = "binomial", alpha = 1,
             lambda = current_lambda, intercept = FALSE)
    }, error = function(e) {
      warning("naive_loglasso: glmnet failed for node ", j, ": ",
              conditionMessage(e), call. = FALSE)
      NULL
    })

    if (is.null(fit)) next

    # The first coefficient row is the intercept slot; drop it.
    node_coefs <- coef(fit, s = current_lambda)
    theta_hat_primary[-j, j] <- as.numeric(node_coefs[-1])
  }
  cat("\nAsymmetric estimation completed.\n")

  cat(paste0("Applying '", toupper(symmetrize_rule), "' rule for symmetrization...\n"))

  rule <- toupper(symmetrize_rule)
  averaged <- (theta_hat_primary + t(theta_hat_primary)) / 2
  nonzero <- abs(theta_hat_primary) > 1e-6

  if (rule %in% c("AND", "OR")) {
    # Whole-matrix masks replace the pairwise loops; this also covers p == 1,
    # where 1:(p - 1) would index outside the matrix.
    keep <- if (rule == "AND") nonzero & t(nonzero) else nonzero | t(nonzero)
    diag(keep) <- FALSE
    kept_cells <- which(keep)

    theta_hat_sym <- matrix(0, nrow = p, ncol = p)
    theta_hat_sym[kept_cells] <- averaged[kept_cells]
  } else {
    cat("Warning: Unknown rule. Using simple averaging for symmetrization.\n")
    theta_hat_sym <- averaged
  }

  cat("Naive-LogLasso completed.\n\n")

  theta_hat_sym
}


