
#########################################
### Real data analysis - MovieLens 1M ###
#########################################

library(readr)
library(dplyr)
library(tidyr)

# ---- Configuration ----
# Download the MovieLens 1M dataset (ratings.dat, users.dat, movies.dat) from
# https://www.kaggle.com/datasets/odedgolden/movielens-1m-dataset?resource=download
# and set `data_dir` to the folder that contains the .dat files.
# "." is the current working directory. Do not use "": file.path("", "x")
# evaluates to "/x", i.e. a file in the filesystem root.
data_dir <- "."

# Number of most-liked movies (within the target group) kept as graph nodes.
top_p_movies <- 100

# A rating >= this threshold counts as "liked" (1).
rating_threshold <- 4

# Demographic column that defines the groups, and the target group's value.
group_col    <- "Age"
target_group <- 25

# Minimum number of users an auxiliary (source) group needs to be used.
min_users_aux <- 200

ratings_path <- file.path(data_dir, "ratings.dat")
users_path   <- file.path(data_dir, "users.dat")
movies_path  <- file.path(data_dir, "movies.dat")

# ---- Load the three '::'-separated, header-less MovieLens .dat files ----
# Column types are given explicitly instead of guessed from the first rows:
# this keeps ZipCode as text (some codes are not purely numeric) and silences
# readr's column-spec messages. movies.dat titles are Latin-1 encoded.
cat("Reading ratings.dat ...\n")
ratings <- read_delim(
  ratings_path,
  delim = "::",
  col_names = c("UserID", "MovieID", "Rating", "Timestamp"),
  col_types = "iiii",
  escape_double = FALSE,
  trim_ws = TRUE,
  progress = FALSE
)

cat("Reading users.dat ...\n")
users <- read_delim(
  users_path,
  delim = "::",
  col_names = c("UserID", "Gender", "Age", "Occupation", "ZipCode"),
  col_types = "iciic",
  escape_double = FALSE,
  trim_ws = TRUE,
  progress = FALSE
)

cat("Reading movies.dat ...\n")
movies <- read_delim(
  movies_path,
  delim = "::",
  col_names = c("MovieID", "Title", "Genres"),
  col_types = "icc",
  locale = locale(encoding = "latin1"),
  escape_double = FALSE,
  trim_ws = TRUE,
  progress = FALSE
)

cat(sprintf("Loaded: %d ratings, %d users, %d movies\n",
            nrow(ratings), nrow(users), nrow(movies)))

cat(sprintf("Binarizing ratings with threshold >= %d ...\n", rating_threshold))

# Wide 0/1 table with one row per user who liked at least one movie and one
# column per liked movie (column name = MovieID as text). A cell is 1L when
# the user rated the movie >= rating_threshold, otherwise 0L (also unrated).
bin_ratings <- ratings %>%
  filter(Rating >= rating_threshold) %>%
  transmute(UserID, MovieID = as.character(MovieID), liked = 1L) %>%
  distinct() %>%
  pivot_wider(
    names_from  = MovieID,
    values_from = liked,
    values_fill = list(liked = 0L)
  )

cat(sprintf("Binary user-movie matrix: %d users × %d movies (before join)\n",
            nrow(bin_ratings), ncol(bin_ratings) - 1))

# Attach the demographic columns; the inner join keeps users present in both.
df_full <- inner_join(bin_ratings, users, by = "UserID")

cat(sprintf("After join: %d users × %d columns\n", nrow(df_full), ncol(df_full)))

# Every column that did not come from users.dat is a movie indicator.
movie_cols <- setdiff(names(df_full), names(users))

cat(sprintf("Detected movie columns (candidate nodes): p = %d\n", length(movie_cols)))


# ---- Target group: 0/1 user x movie matrix X0 over the top movies ----
if (!group_col %in% colnames(df_full)) {
  stop(sprintf("group_col '%s' not found in df_full!", group_col))
}

cat(sprintf("\nUsing group_col = '%s', target_group = '%s'\n",
            group_col, as.character(target_group)))

idx_target <- which(df_full[[group_col]] == target_group)
n_target   <- length(idx_target)

if (n_target == 0) {
  stop(sprintf("No users found for target group: %s = %s",
               group_col, as.character(target_group)))
}

cat(sprintf("Target group '%s = %s': %d users\n",
            group_col, as.character(target_group), n_target))

# Target users x all movie columns; rows are named by UserID.
X_target_full <- df_full[idx_target, movie_cols, drop = FALSE] %>%
  as.matrix()
rownames(X_target_full) <- df_full$UserID[idx_target]

# Keep the top_p_movies movies most often liked by target users.
# head() returns at most top_p_movies names and is safe when there are none,
# whereas names(x)[1:min(top_p_movies, 0)] would index c(1, 0).
movie_freq_target <- colSums(X_target_full)
sorted_movies <- sort(movie_freq_target, decreasing = TRUE)
top_movies <- head(names(sorted_movies), top_p_movies)
p <- length(top_movies)

# Report the number actually selected (may be fewer than requested).
cat(sprintf("Selected top %d movies for target (requested %d; p = %d nodes)\n",
            p, top_p_movies, p))

X0 <- X_target_full[, top_movies, drop = FALSE]

cat(sprintf("Final X0 matrix: %d users × %d movies\n",
            nrow(X0), ncol(X0)))

# Users per group, largest group first.
group_counts <- df_full %>%
  count(.data[[group_col]], name = "n_users") %>%
  arrange(desc(n_users))

cat("\nGroup sizes:\n")
print(group_counts)

# Auxiliary (source) groups: every non-target group with at least
# min_users_aux users, kept in descending-size order.
aux_groups <- local({
  grp_values <- group_counts[[group_col]]
  keep <- which(grp_values != target_group &
                  group_counts$n_users >= min_users_aux)
  grp_values[keep]
})

cat("\nAuxiliary groups (candidates):\n")
print(aux_groups)

if (length(aux_groups) == 0) {
  warning("No auxiliary groups satisfy min_users_aux; x_A_list will be empty.")
}

# One matrix per auxiliary group, restricted to the target's node set
# (top_movies, same column order); rows are named by UserID.
x_A_list <- lapply(aux_groups, function(grp) {
  rows     <- which(df_full[[group_col]] == grp)
  grp_mat  <- as.matrix(df_full[rows, movie_cols, drop = FALSE])
  shared   <- intersect(top_movies, colnames(grp_mat))

  out <- matrix(0L, nrow = nrow(grp_mat), ncol = length(top_movies))
  colnames(out) <- top_movies
  if (length(shared) > 0) {
    out[, shared] <- grp_mat[, shared, drop = FALSE]
  }
  rownames(out) <- df_full$UserID[rows]
  out
})
names(x_A_list) <- as.character(aux_groups)

cat("\n--- Final matrices ---\n")
cat(sprintf("X0 (target '%s = %s'): %d × %d\n",
            group_col, as.character(target_group), nrow(X0), ncol(X0)))

if (length(x_A_list) == 0) {
  cat("x_A_list is empty (no auxiliary groups).\n")
} else {
  aux_dims <- vapply(x_A_list,
                     function(M) paste0(nrow(M), "×", ncol(M)),
                     character(1))
  print(aux_dims)
}




### Single target experiment for MovieLens ###

library(dplyr)
library(ggplot2)

cat("=== Basic info ===\n")
cat(sprintf("Target matrix X0 : %d users × %d movies\n", nrow(X0), ncol(X0)))

# Show "rows×cols" for every auxiliary matrix, or note that there are none.
if (length(x_A_list) == 0) {
  cat("x_A_list is empty (no auxiliary groups)\n")
} else {
  aux_dims <- vapply(x_A_list,
                     function(M) paste0(nrow(M), "×", ncol(M)),
                     character(1))
  cat("Auxiliary groups (rows×cols):\n")
  print(aux_dims)
}

# Node-wise misclassification error of an estimated binary graphical model.
# Defined only if no function of this name exists yet (e.g. from another file).
#
# Arguments:
#   x          matrix with one row per observation and p node columns; the
#              predictions are 0/1, so x is expected to be 0/1 coded.
#   theta_hat  p x p coefficient matrix. Row j (diagonal excluded) gives the
#              logistic coefficients for predicting node j from the others;
#              no intercept is used.
#
# Returns the fraction of wrong 0/1 predictions over all n * p cells. It
# returns NA_real_ when p <= 1, when x has no rows, or when theta_hat is NULL,
# not a matrix, or not p x p.
if (!exists("misclassification_error")) {
  misclassification_error <- function(x, theta_hat) {
    p    <- ncol(x)
    dims <- dim(theta_hat)

    # dim() is checked instead of nrow()/ncol(). For a plain vector, nrow() is
    # NULL, the comparison becomes logical(0), and the `||` chain then errors
    # instead of returning NA.
    if (is.null(p) || p <= 1 || length(dims) != 2L || any(dims != p)) {
      return(NA_real_)
    }

    n <- nrow(x)
    if (n == 0) {
      return(NA_real_)
    }

    # Errors for node j: predict x[, j] from the other columns with
    # P(x_j = 1) = logistic(x_{-j} %*% theta_hat[j, -j]) and threshold at 0.5.
    errors_per_node <- vapply(seq_len(p), function(j) {
      logits <- x[, -j, drop = FALSE] %*% theta_hat[j, -j]
      probs  <- 1 / (1 + exp(-logits))
      preds  <- ifelse(probs >= 0.5, 1L, 0L)
      as.numeric(sum(preds != x[, j]))
    }, numeric(1))

    sum(errors_per_node) / (as.numeric(n) * p)
  }
}

set.seed(2025)

k_folds <- 5
n0      <- nrow(X0)

cat(sprintf("\n=== Starting %d-fold CV on MovieLens target (n0 = %d, p = %d) ===\n",
            k_folds, n0, ncol(X0)))

# Random assignment of the target users to folds 1..k_folds (near-equal sizes).
fold_id <- sample(cut(seq_len(n0), breaks = k_folds, labels = FALSE))

results <- matrix(NA_real_, nrow = k_folds, ncol = 3)
colnames(results) <- c("Trans-Ising", "Pooled-Trans-Ising", "Naive-LogLasso")

# Fold-invariant quantities, computed once. vapply() returns integer(0) for
# an empty x_A_list, so n_A is 0 there. sum(sapply(list(), nrow)) would stop
# with "invalid 'type' (list)" because sapply() returns list().
p       <- ncol(X0)
n_A     <- sum(vapply(x_A_list, nrow, integer(1)))
has_aux <- length(x_A_list) > 0

for (i in seq_len(k_folds)) {
  cat(sprintf("\n===== Fold %d / %d =====\n", i, k_folds))

  test_idx  <- which(fold_id == i)
  train_idx <- setdiff(seq_len(n0), test_idx)

  X0_train <- X0[train_idx, , drop = FALSE]
  X0_test  <- X0[test_idx,  , drop = FALSE]

  cat(sprintf("[Fold %d] X0_train: %d × %d, X0_test: %d × %d\n",
              i, nrow(X0_train), ncol(X0_train),
              nrow(X0_test),  ncol(X0_test)))

  n_0 <- nrow(X0_train)

  # lambda_j = 1.3 * sqrt(log(p) / (n_0 + n_A)), using the total sample size.
  lambda_j_val <- 1.3 * sqrt(log(p) / (n_0 + n_A))

  tryCatch({
    if (has_aux) {
      cat("Training Trans-Ising...\n")
      fit_trans <- trans_loglasso(
        x_A_list = x_A_list,
        x_0      = X0_train,
        lambda_j = lambda_j_val
      )
      beta_hat_trans <- fit_trans$beta_hat

      cat("Training Pooled-Trans-Ising...\n")
      fit_pooled <- oracle_trans_loglasso(
        x_A_list = x_A_list,
        x_0      = X0_train,
        lambda_j = lambda_j_val
      )
      beta_hat_pooled <- fit_pooled$beta_hat
    } else {
      # The transfer methods need at least one source group; report NA.
      cat("No auxiliary groups: skipping Trans-Ising and Pooled-Trans-Ising.\n")
      beta_hat_trans  <- NULL
      beta_hat_pooled <- NULL
    }

    cat("Training Naive-LogLasso...\n")
    beta_hat_naive <- naive_loglasso(x_0 = X0_train)

    cat("Evaluating on test set...\n")
    # misclassification_error() returns NA for a NULL estimate.
    err_trans  <- misclassification_error(X0_test, beta_hat_trans)
    err_pooled <- misclassification_error(X0_test, beta_hat_pooled)
    err_naive  <- misclassification_error(X0_test, beta_hat_naive)

    results[i, ] <- c(err_trans, err_pooled, err_naive)

    cat(sprintf("Fold %d Misclassification Error:\n  Trans   = %.4f\n  Pooled  = %.4f\n  Naive   = %.4f\n",
                i, err_trans, err_pooled, err_naive))

  }, error = function(e) {
    cat(sprintf("  [Fold %d] ERROR: %s\n", i, e$message))
  })
}

cat("\n=== Cross-Validation Summary ===\n")
print(results)

# Mean and SD of the per-fold errors; failed folds (NA) are ignored.
avg <- colMeans(results, na.rm = TRUE)
sdv <- apply(results, 2, sd, na.rm = TRUE)

summary_df <- data.frame(
  Model      = names(avg),
  Mean_Error = as.numeric(avg),
  SD_Error   = as.numeric(sdv)
)

print(summary_df)

# Mean error of each transfer method divided by Naive-LogLasso's (< 1 = better).
rel_df <- summary_df %>%
  mutate(Baseline = Mean_Error[Model == "Naive-LogLasso"]) %>%
  filter(Model != "Naive-LogLasso") %>%
  mutate(Relative = Mean_Error / Baseline)

cat("\n=== Relative Error (vs Naive-LogLasso) ===\n")
print(rel_df)

# Drop undefined ratios (e.g. NaN when every fold of a model failed). The plot
# is printed explicitly because a top-level value is not auto-printed when
# this script is run through source().
rel_plot_df <- filter(rel_df, is.finite(Relative))

if (nrow(rel_plot_df) > 0) {
  rel_plot <- ggplot(rel_plot_df, aes(x = Model, y = Relative, fill = Model)) +
    geom_bar(stat = "identity", width = 0.6, alpha = 0.85) +
    geom_hline(yintercept = 1, linetype = "dashed") +
    labs(
      title = "MovieLens (Target Group) — Relative Misclassification Error",
      x = NULL,
      y = "Relative Error vs Naive"
    ) +
    theme_minimal(base_size = 14) +
    theme(legend.position = "none")
  print(rel_plot)
}



### MovieLens experiment for all targets ###

library(dplyr)

group_col        <- "Age"
top_p_movies     <- 100
min_users_target <- 200   # minimum users for a group to be used as a target
min_users_aux    <- 200   # minimum users for a group to be used as a source
k_folds          <- 5
lambda_mult      <- 1.3   # passed to run_trans_ising_for_target()

set.seed(2025)

# Users per group of `group_col`, largest first. Grouping uses group_col
# rather than a hard-coded Age column, so changing group_col takes effect.
age_summary <- df_full %>%
  group_by(.data[[group_col]]) %>%
  summarise(n_users = n(), .groups = "drop") %>%
  arrange(desc(n_users))

cat("=== Age group sizes ===\n")
print(age_summary)

# Target groups: those with at least min_users_target users.
target_ages <- age_summary[[group_col]][which(age_summary$n_users >= min_users_target)]

cat("\n=== Target Age groups (n_users >= ", min_users_target, ") ===\n", sep = "")
print(target_ages)

# k-fold CV comparison of Trans-Ising, Pooled-Trans-Ising and Naive-LogLasso
# for a single target group.
#
# Arguments:
#   target_age    value of `group_col` identifying the target group.
#   df_full       user-level table: demographic columns plus one 0/1 column
#                 per movie (the columns named in movie_cols).
#   movie_cols    names of the movie columns of df_full.
#   group_col     column defining the groups (default "Age").
#   top_p_movies  number of most-liked target movies used as nodes.
#   min_users_aux minimum group size for an auxiliary (source) group.
#   k_folds       number of CV folds over the target users.
#   lambda_mult   multiplier c in lambda_j = c * sqrt(log(p) / (n_0 + n_A)).
#                 NULL passes lambda_j = NULL, leaving the choice to
#                 trans_loglasso() / oracle_trans_loglasso().
#
# Returns NULL when the target group is empty or yields fewer than two movies.
# Otherwise returns a list with target_age, summary_df (mean/SD error per
# model), fold_results (k_folds x 3 error matrix) and aux_groups. Fold-level
# errors are caught and reported; that fold's results stay NA.
run_trans_ising_for_target <- function(target_age,
                                       df_full,
                                       movie_cols,
                                       group_col = "Age",
                                       top_p_movies = 100,
                                       min_users_aux = 200,
                                       k_folds = 5,
                                       lambda_mult = 1.3) {

  cat("\n\n############################################################\n")
  cat(sprintf("###   STARTING EXPERIMENT FOR TARGET: %s = %s   ###\n",
              group_col, as.character(target_age)))
  cat("############################################################\n")

  # ---- Target matrix X0 ----
  idx_target <- which(df_full[[group_col]] == target_age)
  n_target   <- length(idx_target)

  if (n_target == 0) {
    cat("  -> No users in this target group. Skipping.\n")
    return(NULL)
  }

  cat(sprintf("[Target] %s = %s : %d users\n",
              group_col, as.character(target_age), n_target))

  X_target_full <- as.matrix(df_full[idx_target, movie_cols, drop = FALSE])
  rownames(X_target_full) <- df_full$UserID[idx_target]

  # Nodes: the movies most often liked by target users. head() is safe when
  # fewer (or no) movies exist, unlike names(x)[1:min(top_p_movies, n)].
  sorted_movies <- sort(colSums(X_target_full), decreasing = TRUE)
  top_movies    <- head(names(sorted_movies), top_p_movies)
  p <- length(top_movies)

  if (p < 2) {
    cat(sprintf("  -> Too few movies selected (p = %d). Skipping.\n", p))
    return(NULL)
  }

  cat(sprintf("  -> Selected top %d movies (requested %d; p = %d nodes)\n",
              p, top_p_movies, p))

  X0 <- X_target_full[, top_movies, drop = FALSE]
  n0 <- nrow(X0)
  cat(sprintf("  -> Final X0: %d users × %d movies\n", n0, p))

  # ---- Auxiliary (source) groups ----
  group_counts <- df_full %>%
    group_by(.data[[group_col]]) %>%
    summarise(n_users = n(), .groups = "drop") %>%
    arrange(desc(n_users))

  group_values <- group_counts[[group_col]]
  aux_groups <- group_values[which(group_values != target_age &
                                     group_counts$n_users >= min_users_aux)]

  cat("\n  Auxiliary groups (candidates):\n")
  print(aux_groups)

  if (length(aux_groups) == 0) {
    cat("  -> No auxiliary groups with enough users. Will only fit Naive.\n")
  }

  # Each source matrix uses the target's node set in the same column order.
  x_A_list <- lapply(aux_groups, function(g) {
    idx    <- which(df_full[[group_col]] == g)
    M_full <- as.matrix(df_full[idx, movie_cols, drop = FALSE])
    common <- intersect(top_movies, colnames(M_full))

    M_sub <- matrix(0L, nrow = nrow(M_full), ncol = length(top_movies))
    colnames(M_sub) <- top_movies
    if (length(common) > 0) {
      M_sub[, common] <- M_full[, common, drop = FALSE]
    }
    rownames(M_sub) <- df_full$UserID[idx]
    M_sub
  })
  names(x_A_list) <- as.character(aux_groups)

  has_aux <- length(x_A_list) > 0
  if (has_aux) {
    aux_dims <- vapply(x_A_list,
                       function(M) paste0(nrow(M), "×", ncol(M)),
                       character(1))
    cat("\n  x_A_list dimensions (rows×cols):\n")
    print(aux_dims)
  }

  # Total source sample size (fold-invariant; 0 for an empty list).
  n_A <- sum(vapply(x_A_list, nrow, integer(1)))

  # ---- Cross-validation ----
  cat(sprintf("\n=== %s=%s : Starting %d-fold CV (n0 = %d, p = %d) ===\n",
              group_col, as.character(target_age), k_folds, n0, p))

  fold_id <- sample(cut(seq_len(n0), breaks = k_folds, labels = FALSE))
  results <- matrix(NA_real_, nrow = k_folds, ncol = 3)
  colnames(results) <- c("Trans-Ising", "Pooled-Trans-Ising", "Naive-LogLasso")

  for (i in seq_len(k_folds)) {
    cat(sprintf("\n----- Fold %d / %d -----\n", i, k_folds))

    test_idx  <- which(fold_id == i)
    train_idx <- setdiff(seq_len(n0), test_idx)

    X0_train <- X0[train_idx, , drop = FALSE]
    X0_test  <- X0[test_idx,  , drop = FALSE]

    cat(sprintf("[Fold %d] X0_train: %d×%d, X0_test: %d×%d\n",
                i, nrow(X0_train), ncol(X0_train),
                nrow(X0_test),  ncol(X0_test)))

    n_0 <- nrow(X0_train)

    # Honour lambda_mult. NULL defers the choice of lambda_j to the fitting
    # functions. lambda_j only matters when sources exist, so n_A > 0 here.
    lambda_j_val <- if (is.null(lambda_mult)) {
      NULL
    } else {
      lambda_mult * sqrt(log(p) / (n_0 + n_A))
    }

    tryCatch({
      cat("  Training Naive-LogLasso...\n")
      beta_hat_naive <- naive_loglasso(x_0 = X0_train)

      if (has_aux) {
        cat("  Training Trans-Ising...\n")
        fit_trans <- trans_loglasso(
          x_A_list = x_A_list,
          x_0      = X0_train,
          lambda_j = lambda_j_val
        )
        beta_hat_trans <- fit_trans$beta_hat

        cat("  Training Pooled-Trans-Ising...\n")
        fit_pooled <- oracle_trans_loglasso(
          x_A_list = x_A_list,
          x_0      = X0_train,
          lambda_j = lambda_j_val
        )
        beta_hat_pooled <- fit_pooled$beta_hat
      } else {
        beta_hat_trans  <- NULL
        beta_hat_pooled <- NULL
      }

      cat("  Evaluating on test set...\n")
      err_naive <- misclassification_error(X0_test, beta_hat_naive)
      if (!is.null(beta_hat_trans)) {
        err_trans  <- misclassification_error(X0_test, beta_hat_trans)
        err_pooled <- misclassification_error(X0_test, beta_hat_pooled)
      } else {
        err_trans  <- NA_real_
        err_pooled <- NA_real_
      }

      results[i, ] <- c(err_trans, err_pooled, err_naive)

      cat(sprintf("  Fold %d Errors: Trans=%.4f, Pooled=%.4f, Naive=%.4f\n",
                  i, err_trans, err_pooled, err_naive))

    }, error = function(e) {
      cat(sprintf("  [Fold %d] ERROR: %s\n", i, e$message))
    })
  }

  # ---- Summary over folds (failed folds, NA, are ignored) ----
  avg <- colMeans(results, na.rm = TRUE)
  sdv <- apply(results, 2, sd, na.rm = TRUE)

  summary_df <- data.frame(
    Model      = names(avg),
    Mean_Error = as.numeric(avg),
    SD_Error   = as.numeric(sdv),
    Target     = target_age,
    n_target   = n0,
    p          = p
  )

  cat("\n=== Summary for target", group_col, "=", target_age, "===\n")
  print(summary_df)

  list(
    target_age   = target_age,
    summary_df   = summary_df,
    fold_results = results,
    aux_groups   = aux_groups
  )
}

# Run the CV experiment for each eligible target, in order. Targets that are
# skipped (the runner returns NULL) are not kept.
all_results <- Filter(
  Negate(is.null),
  setNames(
    lapply(target_ages, function(ag) {
      run_trans_ising_for_target(
        target_age    = ag,
        df_full       = df_full,
        movie_cols    = movie_cols,
        group_col     = group_col,
        top_p_movies  = top_p_movies,
        min_users_aux = min_users_aux,
        k_folds       = k_folds,
        lambda_mult   = lambda_mult
      )
    }),
    as.character(target_ages)
  )
)

# Stack the per-target summary tables into one long data frame.
summary_list_ml <- Filter(Negate(is.null), lapply(all_results, `[[`, "summary_df"))
final_summary_ml <- summary_list_ml %>% bind_rows()

cat("\n\n================= FINAL SUMMARY OVER ALL AGE TARGETS =================\n")
print(final_summary_ml)

write.csv(final_summary_ml, "movielens_transising_age_results.csv", row.names = FALSE)





### movielens plots ###

library(dplyr)
library(rlang)
library(ggplot2)
library(tidyr)

# Grouped bar chart of each model's mean error relative to a baseline model.
#
# Arguments:
#   df              data frame with columns Model, Mean_Error, SD_Error and
#                   the grouping column named by `group_col`. If df also has
#                   an `n_target_inv` column, the x axis is ordered by
#                   decreasing n_target_inv.
#   group_col       name of the grouping column (x axis).
#   baseline_model  model used as denominator; excluded from the bars and
#                   represented by the dashed line at 1.
#   title           plot title.
#   base_family     font family for theme_minimal(). The default keeps the
#                   original "Times New Roman"; use "serif" or "" if that font
#                   is unavailable on the graphics device.
#   show_error_bars if TRUE, add +/- Propagated_SD error bars (FALSE keeps the
#                   original plot).
#
# Returns the ggplot object (it is not printed). Groups without a baseline
# row, or with a non-finite ratio, are dropped.
plot_relative_errors <- function(df,
                                 group_col = "Target",
                                 baseline_model = "Naive-LogLasso",
                                 title = "Comparison of Model Performance (Relative to Naive)",
                                 base_family = "Times New Roman",
                                 show_error_bars = FALSE) {
  gsym <- sym(group_col)

  rel_df <- df %>%
    group_by(!!gsym) %>%
    mutate(
      BaseMean = Mean_Error[Model == baseline_model][1],
      BaseSD   = SD_Error[Model == baseline_model][1]
    ) %>%
    ungroup() %>%
    filter(Model != baseline_model) %>%
    mutate(
      Relative_Error = Mean_Error / BaseMean,
      # First-order error propagation for a ratio, treating numerator and
      # denominator as independent; pmax() guards against division by zero.
      Propagated_SD  = sqrt(
        (SD_Error / pmax(Mean_Error, 1e-12))^2 +
          (BaseSD   / pmax(BaseMean,   1e-12))^2
      ) * Relative_Error
    ) %>%
    filter(is.finite(Relative_Error))

  if ("n_target_inv" %in% names(df)) {
    order_levels <- df %>%
      distinct(!!gsym, n_target_inv) %>%
      arrange(desc(n_target_inv)) %>%
      pull(!!gsym)
    rel_df[[group_col]] <- factor(rel_df[[group_col]], levels = order_levels)
  }

  dodge <- position_dodge(width = 0.9)

  plt <- ggplot(rel_df, aes(x = !!gsym, y = Relative_Error, fill = Model)) +
    geom_bar(stat = "identity", position = dodge, alpha = 0.85)

  if (isTRUE(show_error_bars)) {
    plt <- plt +
      geom_errorbar(
        aes(ymin = Relative_Error - Propagated_SD,
            ymax = Relative_Error + Propagated_SD),
        position = dodge,
        width = 0.25
      )
  }

  plt +
    geom_hline(yintercept = 1, linetype = "dashed", color = "red", linewidth = 1) +
    labs(
      title = title,
      x = group_col,
      y = "Relative Misclassification Error",
      fill = "Model"
    ) +
    theme_minimal(base_family = base_family, base_size = 20) +
    theme(
      plot.title = element_text(hjust = 0.5, face = "bold", size = 30),
      axis.title = element_text(face = "bold", size = 27),
      axis.text.y = element_text(size = 23),
      axis.text.x = element_text(angle = 45, hjust = 1, face = "bold", size = 27),
      legend.position = "bottom",
      legend.title = element_blank(),
      legend.text = element_text(size = 30),
      legend.key.size = unit(1.0, "cm"),
      legend.margin = margin(t = 20)
    )
}

# Relative-error plot over all targets, with Target levels in sorted order.
p_ml <- final_summary_ml %>%
  dplyr::mutate(Target = factor(Target, levels = sort(unique(Target)))) %>%
  plot_relative_errors(
    group_col = "Target",
    title = "Comparison of Model Performance (MovieLens)"
  )

print(p_ml)

write.csv(final_summary_ml, "movielens_transising_age_results.csv", row.names = FALSE)

# Save a large high-resolution copy of the figure (40 x 30 cm at 600 dpi).
output_filename <- "realdata_movielens.png"
ggsave(filename = output_filename, plot = p_ml,
       width = 40, height = 30, units = "cm", dpi = 600)

cat(sprintf("Saved MovieLens plot to '%s'\n", output_filename))


# ---- Re-run targets 25 and 1, replacing their rows in final_summary_ml ----
# Statement order is kept as-is: every run draws CV folds from the global RNG.

# Target 25 with a smaller lambda multiplier.
# NOTE(review): lambda_mult = 0.5 only has an effect if
# run_trans_ising_for_target() actually uses lambda_mult -- confirm.
res_25_new <- run_trans_ising_for_target(
  25, df_full, movie_cols,
  group_col = group_col, top_p_movies = top_p_movies,
  min_users_aux = min_users_aux, k_folds = k_folds,
  lambda_mult = 0.5
)

final_summary_ml <- dplyr::bind_rows(
  dplyr::filter(final_summary_ml, Target != 25),
  res_25_new$summary_df
)
all_results[["25"]] <- res_25_new

p_ml <- final_summary_ml %>%
  dplyr::mutate(Target = factor(Target, levels = sort(unique(Target)))) %>%
  plot_relative_errors(group_col = "Target", title = "Comparison of Model Performance")
print(p_ml)

# Target 1 with the globally configured lambda_mult.
res_1 <- run_trans_ising_for_target(
  1, df_full, movie_cols,
  group_col = group_col, top_p_movies = top_p_movies,
  min_users_aux = min_users_aux, k_folds = k_folds,
  lambda_mult = lambda_mult
)

final_summary_ml <- dplyr::bind_rows(
  dplyr::filter(final_summary_ml, Target != 1),
  res_1$summary_df
)
all_results[["1"]] <- res_1

p_ml <- final_summary_ml %>%
  dplyr::mutate(Target = factor(Target, levels = sort(unique(Target)))) %>%
  plot_relative_errors(group_col = "Target", title = "Comparison of Model Performance")
print(p_ml)




### Informative Source for MovieLens ###

library(dplyr)

set.seed(2025)

group_col        <- "Age"
top_p_movies     <- 100
min_users_target <- 200   # minimum users for a group to be used as a target
min_users_aux    <- 200   # minimum users for a group to be used as a source

# Users per group of `group_col`, largest first. Grouping uses group_col
# rather than a hard-coded Age column, so changing group_col takes effect.
age_summary <- df_full %>%
  group_by(.data[[group_col]]) %>%
  summarise(n_users = n(), .groups = "drop") %>%
  arrange(desc(n_users))

cat("=== Age group sizes ===\n")
print(age_summary)

# Target groups: those with at least min_users_target users.
target_ages <- age_summary[[group_col]][which(age_summary$n_users >= min_users_target)]

cat("\n=== Target age groups (n_users >= ", min_users_target, ") ===\n", sep = "")
print(target_ages)

# Detect which auxiliary groups are informative for one target group, using
# the informative_set reported by the Trans-Ising fit on all target users.
#
# Arguments:
#   target_age    value of `group_col` identifying the target group.
#   df_full       user-level table with demographic columns and one 0/1
#                 column per movie (named in movie_cols).
#   movie_cols    names of the movie columns in df_full.
#   group_col     grouping column (default "Age").
#   top_p_movies  number of most-liked target movies used as nodes.
#   min_users_aux minimum group size for a candidate source group.
#
# Returns NULL if the target is empty or has fewer than two movies. Otherwise
# returns a data.frame(TargetAge, InformativeSource) with one row per
# informative source, or a single NA row when there is none.
run_source_detection_for_age <- function(target_age,
                                         df_full,
                                         movie_cols,
                                         group_col     = "Age",
                                         top_p_movies  = 100,
                                         min_users_aux = 200) {
  cat("\n\n############################################################\n")
  cat(sprintf("###  RUNNING SOURCE DETECTION FOR TARGET: %s = %s  ###\n",
              group_col, as.character(target_age)))
  cat("############################################################\n")

  na_row <- data.frame(
    TargetAge         = target_age,
    InformativeSource = NA,
    stringsAsFactors  = FALSE
  )

  # ---- Target matrix ----
  df_target <- df_full %>%
    filter(.data[[group_col]] == target_age)
  n_target <- nrow(df_target)
  if (n_target == 0) {
    cat("  -> No users in this target age. Skipping.\n")
    return(NULL)
  }

  cat(sprintf("  Target age %s: %d users\n",
              as.character(target_age), n_target))

  X_target_full <- as.matrix(df_target[, movie_cols, drop = FALSE])
  rownames(X_target_full) <- df_target$UserID

  # Nodes: the most-liked target movies; head() is safe for short inputs.
  sorted_movies <- sort(colSums(X_target_full), decreasing = TRUE)
  top_movies    <- head(names(sorted_movies), top_p_movies)
  p <- length(top_movies)
  if (p < 2) {
    cat(sprintf("  -> Too few movies selected (p = %d). Skipping.\n", p))
    return(NULL)
  }
  cat(sprintf("  -> Selected top %d movies (requested %d; p = %d nodes)\n",
              p, top_p_movies, p))

  x_0_full <- X_target_full[, top_movies, drop = FALSE]

  # ---- Candidate source groups ----
  group_counts <- df_full %>%
    group_by(.data[[group_col]]) %>%
    summarise(n_users = n(), .groups = "drop") %>%
    arrange(desc(n_users))

  group_values <- group_counts[[group_col]]
  aux_groups <- group_values[which(group_values != target_age &
                                     group_counts$n_users >= min_users_aux)]

  cat("\n  Auxiliary age groups (candidates):\n")
  print(aux_groups)

  if (length(aux_groups) == 0) {
    cat("  -> No auxiliary age groups with enough users. Returning NA row.\n")
    return(na_row)
  }

  # Source matrices restricted to the target's node set (same column order).
  x_A_list_full <- lapply(aux_groups, function(g) {
    df_aux     <- df_full %>% filter(.data[[group_col]] == g)
    X_aux_full <- as.matrix(df_aux[, movie_cols, drop = FALSE])
    common     <- intersect(top_movies, colnames(X_aux_full))

    M_sub <- matrix(0L, nrow = nrow(X_aux_full), ncol = length(top_movies))
    colnames(M_sub) <- top_movies
    if (length(common) > 0) {
      M_sub[, common] <- X_aux_full[, common, drop = FALSE]
    }
    rownames(M_sub) <- df_aux$UserID
    M_sub
  })
  names(x_A_list_full) <- as.character(aux_groups)

  # ---- lambda_j and fitting function ----
  # For the listed target values lambda_j is NULL, leaving its choice to the
  # fitting function. Target 25 is fitted with trans_loglasso_2(). Any other
  # target uses lambda_j = 1.0 * sqrt(log(p) / (n_0 + n_A)).
  use_internal_lambda <- target_age %in% c(1, 18, 25, 35, 45, 50, 56)
  use_variant_2       <- target_age == 25
  trans_fun      <- if (use_variant_2) trans_loglasso_2 else trans_loglasso
  trans_fun_name <- if (use_variant_2) "trans_loglasso_2" else "trans_loglasso"

  if (use_internal_lambda) {
    lambda_j_val <- NULL
    # The name is interpolated as a string: sprintf() with a NULL argument
    # returns character(0), so a "%.4g" of lambda_j would print nothing.
    cat(sprintf("\n  Using lambda_j = NULL (internal default in %s)\n",
                trans_fun_name))
  } else {
    n_0 <- nrow(x_0_full)
    n_A <- sum(vapply(x_A_list_full, nrow, integer(1)))
    lambda_j_val <- 1.0 * sqrt(log(p) / (n_0 + n_A))
    cat(sprintf("\n  Using default lambda_mult = 1.0 → lambda_j = %.4g\n",
                lambda_j_val))
  }

  cat("  Running source detection...\n")
  fit_trans <- trans_fun(
    x_A_list = x_A_list_full,
    x_0      = x_0_full,
    lambda_j = lambda_j_val
  )

  # NOTE(review): informative_set is assumed to index positions of
  # x_A_list_full (same order as aux_groups) -- confirm against trans_loglasso.
  informative_idx <- fit_trans$informative_set

  if (length(informative_idx) == 0) {
    cat("  -> No informative sources detected for this target.\n")
    return(na_row)
  }

  informative_ages <- aux_groups[informative_idx]
  cat(sprintf("  -> Found %d informative source ages: %s\n",
              length(informative_ages),
              paste(informative_ages, collapse = ", ")))

  data.frame(
    TargetAge         = target_age,
    InformativeSource = informative_ages,
    stringsAsFactors  = FALSE
  )
}


# Source detection for each eligible target age, in order. A NULL result
# (skipped target) is not kept.
informative_age_results <- Filter(
  Negate(is.null),
  setNames(
    lapply(target_ages, function(ag) {
      run_source_detection_for_age(
        target_age    = ag,
        df_full       = df_full,
        movie_cols    = movie_cols,
        group_col     = group_col,
        top_p_movies  = top_p_movies,
        min_users_aux = min_users_aux
      )
    }),
    as.character(target_ages)
  )
)

# One row per (target age, informative source) pair.
final_source_age_df <- informative_age_results %>% bind_rows()
rownames(final_source_age_df) <- NULL

cat("\n\n================= INFORMATIVE SOURCE AGE GROUPS =================\n")
print(final_source_age_df)

write.csv(final_source_age_df,
          "movielens_informative_source_ages_custom_lambda.csv",
          row.names = FALSE)


# Re-run source detection for one target age and replace that age's entry.
# The result is stored under the key of the age actually re-run. Previously
# the age-18 run (held in `res_35_new`) was stored under "25", which
# duplicated age 18 in the table and dropped age 25's results.
# NOTE(review): confirm which age was meant to be re-run and set rerun_age.
rerun_age <- 18

res_rerun <- run_source_detection_for_age(
  target_age     = rerun_age,
  df_full        = df_full,
  movie_cols     = movie_cols,
  group_col      = group_col,
  top_p_movies   = top_p_movies,
  min_users_aux  = min_users_aux
)

informative_age_results[[as.character(rerun_age)]] <- res_rerun

final_source_age_df <- bind_rows(informative_age_results)
rownames(final_source_age_df) <- NULL

cat("\n\n================= INFORMATIVE SOURCE AGE GROUPS (UPDATED) =================\n")
print(final_source_age_df)
