

#################################################
### Real data analysis - DepMap Mutation Data ###
#################################################

library(depmap)
library(tidyr)
library(dplyr)

# Seed the RNG so the random fold assignment further down is reproducible.
set.seed(2025)

# Mutation calls plus, for every depmap_id, its primary disease label.
mut_data <- depmap_mutationCalls()
meta <- select(depmap::depmap_metadata(), depmap_id, primary_disease)

# Wide 0/1 table: one row per depmap_id, one column per gene. An entry is 1
# when at least one call with that gene name exists for the id, otherwise 0.
bin_mut <- mut_data %>%
  filter(!is.na(gene_name)) %>%
  distinct(depmap_id, gene_name) %>%
  mutate(mutated = 1) %>%
  pivot_wider(
    names_from = gene_name,
    values_from = mutated,
    values_fill = 0
  )

# Attach the disease label to every row of the indicator table.
df_full <- bin_mut %>% left_join(meta, by = "depmap_id")


# Target population: rows whose label equals the chosen disease.
target_disease <- "Brain Cancer"
df_target <- filter(df_full, primary_disease == target_disease)

# Auxiliary populations: every other labelled disease that has at least
# `min_samples` rows.
min_samples <- 20
disease_counts <- df_full %>%
  filter(!is.na(primary_disease), primary_disease != target_disease) %>%
  count(primary_disease, name = "n") %>%
  filter(n >= min_samples)

aux_types <- disease_counts[["primary_disease"]]

# One gene-indicator matrix per auxiliary disease (id and label columns
# removed), in the same order as `aux_types`.
x_A_list_full <- lapply(aux_types, function(aux_disease) {
  aux_rows <- filter(df_full, primary_disease == aux_disease)
  as.matrix(select(aux_rows, -depmap_id, -primary_disease))
})

#' Pooled node-wise misclassification rate for a coefficient matrix.
#'
#' For each column j of `x`, the remaining columns are combined linearly using
#' the off-diagonal entries of row j of `theta_hat` (no intercept), mapped
#' through the logistic function and thresholded at 0.5. The 0/1 predictions
#' are compared with the observed column j, and errors are pooled over all
#' columns and rows.
#'
#' @param x Matrix of observations (rows = samples, columns = variables);
#'   entries are compared against 0/1 predictions.
#' @param theta_hat Coefficient matrix indexed as `theta_hat[j, -j]`; its
#'   diagonal is never read.
#' @return The fraction of mispredicted entries, or `NA` when nothing can be
#'   predicted (no rows, or fewer than two columns), matching the later
#'   definition of this helper in the same script.
misclassification_error <- function(x, theta_hat) {
  n_obs <- nrow(x)
  p <- ncol(x)

  # A column can only be predicted from the *other* columns, so at least two
  # columns and one row are needed. Guarding here also avoids the `1:p`
  # pitfall, where p == 0 would iterate over c(1, 0).
  if (n_obs == 0L || p < 2L) {
    return(NA)
  }

  total_errors <- 0
  for (j in seq_len(p)) {
    logits <- x[, -j, drop = FALSE] %*% theta_hat[j, -j]
    probs <- 1 / (1 + exp(-logits))
    preds <- ifelse(probs >= 0.5, 1, 0)
    total_errors <- total_errors + sum(preds != x[, j])
  }

  # Every column contributes one prediction per row.
  total_errors / (as.numeric(n_obs) * p)
}

# ---- k-fold cross-validation on the single target disease ----
# `trans_loglasso()`, `oracle_trans_loglasso()` and `naive_loglasso()` are not
# defined in this file and must already be available in the session.
k_folds <- 5
n_samples <- nrow(df_target)
cat(sprintf("### Starting %d-Fold Cross-Validation for %d samples... ###\n", k_folds, n_samples))

# Cut 1..n_samples into k_folds contiguous bins, then shuffle the labels so
# every target row receives a random fold id.
folds <- sample(cut(seq(1, n_samples), breaks = k_folds, labels = FALSE))

# Rows = folds, columns = methods; entries are filled in as folds complete.
results_matrix <- matrix(NA, nrow = k_folds, ncol = 3)
colnames(results_matrix) <- c("Trans-Ising", "Pooled-Trans-Ising", "Naive-LogLasso")

for (i in seq_len(k_folds)) {

  cat(sprintf("\n===== Processing Fold %d / %d =====\n", i, k_folds))
  test_indices  <- which(folds == i)
  train_indices <- which(folds != i)

  df_train_temp <- df_target[train_indices, ]

  # Rank genes by mutation count on the training rows only, so the held-out
  # fold plays no part in choosing the features.
  gene_columns <- setdiff(names(df_train_temp), c("depmap_id", "primary_disease"))
  freq <- colSums(df_train_temp[, gene_columns])
  # head() keeps at most 200 names; `[1:200]` would pad with NA names if fewer
  # than 200 gene columns existed.
  top_genes <- head(names(sort(freq, decreasing = TRUE)), 200)

  x_train <- df_target[train_indices, top_genes] %>% as.matrix()
  x_test  <- df_target[test_indices, top_genes] %>% as.matrix()

  # Restrict every auxiliary matrix to the genes selected for this fold.
  x_A_list <- lapply(x_A_list_full, function(aux_matrix) {
    valid_genes <- intersect(top_genes, colnames(aux_matrix))
    aux_matrix[, valid_genes, drop = FALSE]
  })

  cat("Training Trans-Ising...\n")
  trans_results <- trans_loglasso(x_A_list = x_A_list, x_0 = x_train)
  beta_hat_trans <- trans_results$beta_hat

  cat("Training Pooled-Trans-Ising...\n")
  pooled_results <- oracle_trans_loglasso(x_A_list = x_A_list, x_0 = x_train)
  beta_hat_pooled <- pooled_results$beta_hat

  # Unlike the two calls above, the result is used directly as the
  # coefficient matrix (no `$beta_hat` extraction).
  cat("Training Naive-LogLasso...\n")
  beta_hat_naive <- naive_loglasso(x_0 = x_train)

  cat("Evaluating models for fold", i, "...\n")
  error_trans <- misclassification_error(x = x_test, theta_hat = beta_hat_trans)
  error_pooled <- misclassification_error(x = x_test, theta_hat = beta_hat_pooled)
  error_naive <- misclassification_error(x = x_test, theta_hat = beta_hat_naive)

  results_matrix[i, "Trans-Ising"] <- error_trans
  results_matrix[i, "Pooled-Trans-Ising"] <- error_pooled
  results_matrix[i, "Naive-LogLasso"] <- error_naive

  cat(sprintf("Fold %d Results: Trans=%.4f, Pooled=%.4f, Naive=%.4f\n", i, error_trans, error_pooled, error_naive))
}

# Per-method mean and standard deviation across folds (NA folds ignored).
final_errors <- colMeans(results_matrix, na.rm = TRUE)
sd_errors <- apply(results_matrix, 2, sd, na.rm = TRUE)

# cat() does not interpret format specifiers, so the header must be built
# with sprintf() first; otherwise a literal "%d" is printed.
cat(sprintf("\n\n===== Final %d-Fold CV Results (Averaged) =====\n", k_folds))
cat(sprintf("Trans-Ising Misclassification Error        : %.4f (sd = %.4f)\n", final_errors["Trans-Ising"], sd_errors["Trans-Ising"]))
cat(sprintf("Pooled-Trans-Ising Misclassification Error : %.4f (sd = %.4f)\n", final_errors["Pooled-Trans-Ising"], sd_errors["Pooled-Trans-Ising"]))
cat(sprintf("Naive-LogLasso Misclassification Error     : %.4f (sd = %.4f)\n", final_errors["Naive-LogLasso"], sd_errors["Naive-LogLasso"]))
cat("==================================================\n")





### Experiment for all target diseases ###

library(depmap)
library(tidyr)
library(dplyr)

# Re-seed so this experiment is reproducible when run on its own.
set.seed(2025)

# Mutation calls plus, for every depmap_id, its primary disease label.
mut_data <- depmap_mutationCalls()
meta <- select(depmap::depmap_metadata(), depmap_id, primary_disease)

# Wide 0/1 table: one row per depmap_id, one column per gene. An entry is 1
# when at least one call with that gene name exists for the id, otherwise 0.
bin_mut <- mut_data %>%
  filter(!is.na(gene_name)) %>%
  distinct(depmap_id, gene_name) %>%
  mutate(mutated = 1) %>%
  pivot_wider(
    names_from = gene_name,
    values_from = mutated,
    values_fill = 0
  )

df_full <- bin_mut %>% left_join(meta, by = "depmap_id")

# Every labelled disease with at least 20 rows takes a turn as the target.
disease_summary <- df_full %>%
  count(primary_disease, name = "n") %>%
  filter(!is.na(primary_disease), n >= 20)

all_target_diseases <- disease_summary[["primary_disease"]]

# Per-disease result tables, keyed by disease name, are collected here.
all_results_list <- list()

k_folds <- 5

#' Pooled node-wise misclassification rate for a coefficient matrix.
#'
#' For each column j of `x`, the remaining columns are combined linearly using
#' the off-diagonal entries of row j of `theta_hat` (no intercept), mapped
#' through the logistic function and thresholded at 0.5. The 0/1 predictions
#' are compared with the observed column j, and errors are pooled over all
#' columns and rows.
#'
#' @param x Matrix of observations (rows = samples, columns = variables);
#'   entries are compared against 0/1 predictions.
#' @param theta_hat Coefficient matrix indexed as `theta_hat[j, -j]`; its
#'   diagonal is never read.
#' @return The fraction of mispredicted entries, or `NA` when nothing can be
#'   predicted (no rows, or fewer than two columns).
misclassification_error <- function(x, theta_hat) {
  n_obs <- nrow(x)
  p <- ncol(x)

  # The "fewer than two columns" check does not depend on j, so it is done
  # once up front. This also covers p == 0, where `1:p` would iterate over
  # c(1, 0) and fail on an out-of-bounds subscript.
  if (n_obs == 0L || p < 2L) {
    return(NA)
  }

  total_errors <- 0
  for (j in seq_len(p)) {
    logits <- x[, -j, drop = FALSE] %*% theta_hat[j, -j]
    probs <- 1 / (1 + exp(-logits))
    preds <- ifelse(probs >= 0.5, 1, 0)
    total_errors <- total_errors + sum(preds != x[, j])
  }

  # Every column contributes one prediction per row.
  total_errors / (as.numeric(n_obs) * p)
}

# Repeat the cross-validation experiment with every disease in
# `all_target_diseases` taking a turn as the target; all remaining labelled
# diseases with enough rows serve as auxiliary sources.
# `trans_loglasso()`, `oracle_trans_loglasso()` and `naive_loglasso()` are not
# defined in this file and must already be available in the session.
for (current_disease in all_target_diseases) {

  cat(sprintf("\n\n############################################################\n"))
  cat(sprintf("###   STARTING EXPERIMENT FOR: %s   ###\n", current_disease))
  cat(sprintf("############################################################\n"))

  target_disease <- current_disease

  df_target <- df_full %>%
    filter(primary_disease == target_disease)

  # Auxiliary sources: every other labelled disease with >= min_samples rows.
  min_samples <- 20
  disease_counts <- df_full %>%
    filter(!is.na(primary_disease) & primary_disease != target_disease) %>%
    group_by(primary_disease) %>%
    summarise(n = n(), .groups = 'drop') %>%
    filter(n >= min_samples)

  aux_types <- disease_counts$primary_disease

  # One gene-indicator matrix per auxiliary disease, in `aux_types` order.
  x_A_list_full <- lapply(aux_types, function(d){
    df_full %>%
      filter(primary_disease == d) %>%
      select(-depmap_id, -primary_disease) %>%
      as.matrix()
  })

  n_samples <- nrow(df_target)
  cat(sprintf("### Starting %d-Fold Cross-Validation for %d samples... ###\n", k_folds, n_samples))

  # Cut 1..n_samples into k_folds contiguous bins, then shuffle the labels so
  # every target row receives a random fold id.
  folds <- sample(cut(seq(1, n_samples), breaks = k_folds, labels = FALSE))
  results_matrix <- matrix(NA, nrow = k_folds, ncol = 3)
  colnames(results_matrix) <- c("Trans-Ising", "Pooled-Trans-Ising", "Naive-LogLasso")

  for (i in seq_len(k_folds)) {
    cat(sprintf("\n===== Processing Fold %d / %d for %s =====\n", i, k_folds, current_disease))

    test_indices  <- which(folds == i)
    train_indices <- which(folds != i)

    df_train_temp <- df_target[train_indices, ]

    # Rank genes by mutation count on the training rows only, so the held-out
    # fold plays no part in choosing the features.
    gene_columns <- setdiff(names(df_train_temp), c("depmap_id", "primary_disease"))
    freq <- colSums(df_train_temp[, gene_columns])
    # head() keeps at most 200 names; `[1:200]` would pad with NA names if
    # fewer than 200 gene columns existed.
    top_genes <- head(names(sort(freq, decreasing = TRUE)), 200)

    x_train <- df_target[train_indices, top_genes] %>% as.matrix()
    x_test  <- df_target[test_indices, top_genes] %>% as.matrix()

    # Restrict every auxiliary matrix to the genes selected for this fold.
    x_A_list <- lapply(x_A_list_full, function(aux_matrix) {
      valid_genes <- intersect(top_genes, colnames(aux_matrix))
      aux_matrix[, valid_genes, drop = FALSE]
    })

    cat("Training Trans-Ising...\n")
    trans_results <- trans_loglasso(x_A_list = x_A_list, x_0 = x_train)
    beta_hat_trans <- trans_results$beta_hat

    cat("Training Pooled-Trans-Ising...\n")
    pooled_results <- oracle_trans_loglasso(x_A_list = x_A_list, x_0 = x_train)
    beta_hat_pooled <- pooled_results$beta_hat

    # Unlike the two calls above, the result is used directly as the
    # coefficient matrix (no `$beta_hat` extraction).
    cat("Training Naive-LogLasso...\n")
    beta_hat_naive <- naive_loglasso(x_0 = x_train)

    cat("Evaluating models for fold", i, "...\n")
    error_trans <- misclassification_error(x = x_test, theta_hat = beta_hat_trans)
    error_pooled <- misclassification_error(x = x_test, theta_hat = beta_hat_pooled)
    error_naive <- misclassification_error(x = x_test, theta_hat = beta_hat_naive)

    results_matrix[i, "Trans-Ising"] <- error_trans
    results_matrix[i, "Pooled-Trans-Ising"] <- error_pooled
    results_matrix[i, "Naive-LogLasso"] <- error_naive

    cat(sprintf("Fold %d Results: Trans=%.4f, Pooled=%.4f, Naive=%.4f\n", i, error_trans, error_pooled, error_naive))
  }

  # Per-method mean and standard deviation across folds (NA folds ignored).
  final_errors <- colMeans(results_matrix, na.rm = TRUE)
  sd_errors <- apply(results_matrix, 2, sd, na.rm = TRUE)

  # One row per method for this disease; stored under the disease name.
  all_results_list[[current_disease]] <- data.frame(
    Model = names(final_errors),
    Mean_Error = final_errors,
    SD_Error = sd_errors,
    N_Samples = n_samples
  )
}

# Closing banner for the all-disease experiment.
cat(
  "\n\n===========================================================\n",
  "###      ALL EXPERIMENTS COMPLETED - FINAL SUMMARY      ###\n",
  "===========================================================\n\n",
  sep = ""
)

# Stack the per-disease tables, prefixing each with a `Disease` column that
# holds the name it was stored under.
per_disease_tables <- lapply(seq_along(all_results_list), function(k) {
  cbind(Disease = names(all_results_list)[k], all_results_list[[k]])
})
final_summary_df <- do.call(rbind, per_disease_tables)

final_summary_df

# Persist the summary; the plotting section reads this file back in.
write.csv(
  final_summary_df,
  file = "depmap_trans_ising_results.csv",
  row.names = FALSE
)

# Drop the row names inherited from the stacked tables before printing.
row.names(final_summary_df) <- NULL

print(final_summary_df)



### Making plots for DepMap Mutation Data ###

# Read back the per-disease summary written by the experiment section.
file_path <- "depmap_trans_ising_results.csv"
df <- read.csv(file_path)

# Within each disease, copy the Naive-LogLasso mean and SD onto every row so
# the other methods can be expressed relative to that baseline.
with_baseline <- df %>%
  group_by(Disease) %>%
  mutate(
    Baseline_Mean_Error = Mean_Error[Model == "Naive-LogLasso"],
    Baseline_SD_Error = SD_Error[Model == "Naive-LogLasso"]
  ) %>%
  ungroup()

# Keep only the non-baseline methods. Relative_Error is the ratio of means;
# Propagated_SD combines the two relative SDs in quadrature and scales the
# result by that ratio.
relative_df <- with_baseline %>%
  filter(Model != "Naive-LogLasso") %>%
  mutate(
    Relative_Error = Mean_Error / Baseline_Mean_Error,
    Propagated_SD = Relative_Error * sqrt(
      (SD_Error / Mean_Error)^2 + (Baseline_SD_Error / Baseline_Mean_Error)^2
    )
  )


# ggplot2 is used from here on but is not attached anywhere else in this file.
library(ggplot2)

# Grouped bar chart of each method's error relative to Naive-LogLasso, with
# +/- one propagated SD as error bars. The expression is left unassigned so it
# is displayed when run at top level.
# NOTE(review): assumes a "Times New Roman" font is available to the graphics
# device -- confirm on the target machine.
ggplot(relative_df, aes(x = Disease, y = Relative_Error, fill = Model)) +
  geom_bar(stat = "identity", position = position_dodge(width = 0.9), alpha = 0.8) +

  # Same dodge width as the bars so each error bar sits on its own bar.
  geom_errorbar(
    aes(ymin = Relative_Error - Propagated_SD, ymax = Relative_Error + Propagated_SD),
    width = 0.25,
    position = position_dodge(width = 0.9)
  ) +

  # Reference line at 1: equal error to the Naive-LogLasso baseline.
  geom_hline(yintercept = 1, linetype = "dashed", color = "red", size = 1) +

  labs(
    title = "Comparison of Model Performance",
    x = "Disease",
    y = "Relative Misclassification Error",
    fill = "Model"
  ) +

  theme_minimal(base_family = "Times New Roman", base_size = 20) +
  theme(
    plot.title = element_text(hjust = 0.5, face = "bold", size = 27),

    axis.title = element_text(face = "bold", size = 27),
    axis.text.y = element_text(size = 23),

    # Disease names are long, so the x labels are rotated.
    axis.text.x = element_text(angle = 60, hjust = 1, face = "bold", size = 27),
    legend.position = "bottom",

    legend.title = element_blank(),
    legend.text = element_text(size = 27),
    legend.key.size = unit(1.0, "cm"),
    legend.margin = margin(t = 20)
  )


# Publication figure: the same grouped bars and reference line as the
# on-screen plot, but without error bars. Built up layer by layer.
p <- ggplot(relative_df, aes(x = Disease, y = Relative_Error, fill = Model))

# Bars, plus a dashed line at 1 marking the Naive-LogLasso baseline.
p <- p +
  geom_bar(stat = "identity", position = position_dodge(width = 0.9), alpha = 0.8)
p <- p +
  geom_hline(yintercept = 1, linetype = "dashed", color = "red", size = 1)

# Titles and axis labels.
p <- p +
  labs(
    title = "Comparison of Model Performance",
    x = "Disease",
    y = "Relative Misclassification Error",
    fill = "Model"
  )

# Base theme followed by the text-size and legend overrides.
p <- p + theme_minimal(base_family = "Times New Roman", base_size = 20)
p <- p +
  theme(
    plot.title = element_text(hjust = 0.5, face = "bold", size = 30),
    axis.title = element_text(face = "bold", size = 27),
    axis.text.y = element_text(size = 23),
    axis.text.x = element_text(angle = 60, hjust = 1, face = "bold", size = 27),
    legend.position = "bottom",
    legend.title = element_blank(),
    legend.text = element_text(size = 30),
    legend.key.size = unit(1.0, "cm"),
    legend.margin = margin(t = 20)
  )

# Write the figure as a 40 x 30 cm PNG at 600 dpi.
ggsave(
  filename = "realdata_mutation.png",
  plot = p,
  width = 40,
  height = 30,
  units = "cm",
  dpi = 600
)




### Informative Sources for DepMap Mutation Data ###

# For every target disease, fit trans_loglasso once on all target rows and
# record which auxiliary diseases it reports in `$informative_set`.
# `trans_loglasso()` is not defined in this file and must already be
# available in the session.
informative_source_results <- list()

set.seed(2025)

for (current_disease in all_target_diseases) {

  cat("\n\n############################################################\n")
  cat(sprintf("###  STARTING SOURCE DETECTION FOR: %s  ###\n", current_disease))
  cat("############################################################\n")

  target_disease <- current_disease
  df_target <- filter(df_full, primary_disease == target_disease)

  # The 200 most frequently mutated genes, ranked on all target rows.
  gene_columns <- setdiff(names(df_target), c("depmap_id", "primary_disease"))
  freq <- colSums(df_target[, gene_columns])
  top_genes <- names(sort(freq, decreasing = TRUE))[seq_len(200)]

  x_0_full <- as.matrix(df_target[, top_genes])

  # Auxiliary sources: every other labelled disease with >= min_samples rows.
  min_samples <- 20
  disease_counts <- df_full %>%
    filter(!is.na(primary_disease), primary_disease != target_disease) %>%
    count(primary_disease, name = "n") %>%
    filter(n >= min_samples)

  aux_types <- disease_counts[["primary_disease"]]

  # One matrix per auxiliary disease, restricted to the selected genes.
  x_A_list_full <- lapply(aux_types, function(aux_disease) {
    aux_rows <- filter(df_full, primary_disease == aux_disease)
    as.matrix(select(aux_rows, all_of(top_genes)))
  })

  cat(sprintf("### Running trans_loglasso for %s... ###\n", current_disease))

  source_detection_run <- trans_loglasso(x_A_list = x_A_list_full, x_0 = x_0_full)

  # Used below to index `aux_types` and `$loss_diffs`.
  # NOTE(review): presumably positions within `x_A_list` -- confirm against
  # trans_loglasso().
  informative_indices <- source_detection_run$informative_set

  if (length(informative_indices) == 0) {
    cat(sprintf("### Found 0 informative sources for %s ###\n", current_disease))

    # Placeholder row so the target still appears in the final table.
    informative_source_results[[current_disease]] <- data.frame(
      TargetDisease = current_disease,
      InformativeSource = NA,
      Loss_Difference = NA
    )
  } else {
    informative_diseases <- aux_types[informative_indices]

    informative_source_results[[current_disease]] <- data.frame(
      TargetDisease = current_disease,
      InformativeSource = informative_diseases,
      Loss_Difference = source_detection_run$loss_diffs[informative_indices]
    )

    cat(sprintf("### Found %d informative sources for %s: %s ###\n",
                length(informative_diseases),
                current_disease,
                paste(informative_diseases, collapse = ", ")))
  }
}

# Stack the per-target tables, drop the generated row names, print and save.
final_source_df <- do.call(rbind, informative_source_results)
row.names(final_source_df) <- NULL

print(final_source_df)

write.csv(final_source_df, "informative_source_file.csv", row.names = FALSE)


