
###############################################
### Real data analysis - Online Retail Data ###
###############################################

library(readxl)
library(dplyr)
library(tidyr)
library(readxl)

# Seed the RNG once, up front, so the random fold assignment done later in the
# script is reproducible.
set.seed(2025)

# Raw workbook from the UCI Machine Learning Repository:
#   https://archive.ics.uci.edu/dataset/352/online+retail
# "Online Retail.xlsx" is expected in the current working directory.
retail_raw <- read_excel(path = "Online Retail.xlsx")

str(retail_raw)

# Cleaned transaction lines:
#   1. drop rows with no invoice number, stock code or country;
#   2. store invoice numbers as character so their prefix can be tested;
#   3. drop cancellations (invoice numbers starting with "C") and any row whose
#      quantity is not strictly positive.
retail_clean <- retail_raw %>%
  filter(!(is.na(InvoiceNo) | is.na(StockCode) | is.na(Country))) %>%
  mutate(InvoiceNo = as.character(InvoiceNo)) %>%
  filter(!startsWith(InvoiceNo, "C") & Quantity > 0)

cat(sprintf("Cleaned data: %d rows\n", nrow(retail_clean)))

# Keep only items that occur often across the whole cleaned dataset.
# NOTE(review): `count()` tallies rows (invoice lines), not distinct invoices,
# so an item listed twice on one invoice counts twice -- the log message below
# nevertheless calls these "orders".
min_item_freq <- 50

item_counts <- count(retail_clean, StockCode, sort = TRUE)

frequent_items <- item_counts$StockCode[item_counts$n >= min_item_freq]

length(frequent_items)
cat(sprintf("Number of frequent items (>= %d orders): %d\n",
            min_item_freq, length(frequent_items)))

# Transaction lines restricted to the frequent items.
retail_filtered <- filter(retail_clean, StockCode %in% frequent_items)

# Wide 0/1 incidence table: one row per (InvoiceNo, Country), one column per
# frequent StockCode, 1L when the item appears on the invoice and 0L otherwise.
invoice_item <- retail_filtered %>%
  distinct(InvoiceNo, Country, StockCode) %>%
  mutate(purchased = 1L) %>%
  pivot_wider(
    id_cols     = c(InvoiceNo, Country),
    names_from  = StockCode,
    values_from = purchased,
    values_fill = list(purchased = 0L)
  )

# Per-invoice metadata, row-aligned with `invoice_item` (and hence `X_full`).
invoice_meta <- select(invoice_item, InvoiceNo, Country)

# Item columns only, as a matrix whose row names are the invoice numbers.
X_full <- as.matrix(select(invoice_item, -c(InvoiceNo, Country)))
rownames(X_full) <- invoice_item$InvoiceNo

cat(sprintf("Full binary matrix: %d invoices × %d items\n",
            nrow(X_full), ncol(X_full)))

# Number of invoices per country, largest first.
country_counts <- count(invoice_meta, Country, sort = TRUE)

print(country_counts)



### Single target experiment for Online Retail ###
#
# k-fold cross-validation for one target country ("Germany"). The target's
# invoices are restricted to its `top_p_items` most frequently bought items.
# Every other country with at least `min_invoices_aux` invoices, except the
# United Kingdom, is aligned to the same item set and used as an auxiliary
# source. trans_loglasso, oracle_trans_loglasso, naive_loglasso and
# misclassification_error are not defined in this script and must already be
# in scope -- TODO confirm where they are sourced from.

target_country   <- "Germany"
top_p_items      <- 200
min_invoices_aux <- 50

# Locate the target's invoices; abort if the country is absent.
idx_target <- which(invoice_meta$Country == target_country)
n_target   <- length(idx_target)
if (n_target == 0) {
  stop(sprintf("No invoices found for target country: %s", target_country))
}

X_target_full <- X_full[idx_target, , drop = FALSE]

cat(sprintf("Target country '%s': %d invoices\n",
            target_country, n_target))

# Graph nodes: the items bought on the most target invoices.
item_freq_target <- colSums(X_target_full)
sorted_items     <- sort(item_freq_target, decreasing = TRUE)
top_items        <- head(names(sorted_items), top_p_items)
p                <- length(top_items)

cat(sprintf("Selected top %d items for target (p = %d nodes)\n",
            top_p_items, p))

X0 <- X_target_full[, top_items, drop = FALSE]

# Candidate sources: every other country with enough invoices, largest first.
aux_countries <- invoice_meta %>%
  mutate(Country = as.character(Country)) %>%
  count(Country, name = "n_invoices") %>%
  filter(Country != target_country, n_invoices >= min_invoices_aux) %>%
  arrange(desc(n_invoices)) %>%
  pull(Country)

cat("Auxiliary countries (candidates):\n")
print(aux_countries)

# Align every source country to the target's item columns; items the source
# matrix lacks stay at 0L.
x_A_list <- lapply(aux_countries, function(country) {
  source_rows <- X_full[which(invoice_meta$Country == country), , drop = FALSE]
  aligned     <- matrix(0L, nrow = nrow(source_rows), ncol = length(top_items),
                        dimnames = list(NULL, top_items))
  shared <- intersect(top_items, colnames(source_rows))
  if (length(shared) > 0) {
    aligned[, shared] <- source_rows[, shared, drop = FALSE]
  }
  aligned
})
names(x_A_list) <- aux_countries

# The United Kingdom is never used as a source.
x_A_list <- x_A_list[!(names(x_A_list) %in% "United Kingdom")]

names(x_A_list)

cat("\n--- Final matrices ---\n")
cat(sprintf("X0 (target '%s'): %d × %d\n", target_country, nrow(X0), ncol(X0)))
sapply(x_A_list, dim)

# Random, roughly equal-sized fold labels over the target invoices.
k_folds <- 5
n0      <- nrow(X0)
fold_id <- sample(cut(seq_len(n0), breaks = k_folds, labels = FALSE))

results <- matrix(
  NA_real_, nrow = k_folds, ncol = 3,
  dimnames = list(NULL, c("Trans-Ising", "Pooled-Trans-Ising", "Naive-LogLasso"))
)

for (i in seq_len(k_folds)) {
  cat(sprintf("\n===== Fold %d / %d =====\n", i, k_folds))

  test_idx  <- which(fold_id == i)
  train_idx <- which(fold_id != i)

  X0_train <- X0[train_idx, , drop = FALSE]
  X0_test  <- X0[test_idx, , drop = FALSE]

  # Tuning parameter shrinks with the pooled (target + sources) sample size.
  p            <- ncol(X0_train)
  n_0          <- nrow(X0_train)
  n_A          <- sum(sapply(x_A_list, nrow))
  lambda_j_val <- 0.5 * sqrt(log(p) / (n_0 + n_A))

  cat("Training Trans-Ising...\n")
  fit_trans <- trans_loglasso(
    x_A_list = x_A_list, x_0 = X0_train, lambda_j = lambda_j_val
  )

  cat("Training Pooled-Trans-Ising...\n")
  fit_pooled <- oracle_trans_loglasso(
    x_A_list = x_A_list, x_0 = X0_train, lambda_j = lambda_j_val
  )

  cat("Training Naive-LogLasso...\n")
  fit_naive <- naive_loglasso(x_0 = X0_train)

  err_trans  <- misclassification_error(X0_test, fit_trans$beta_hat)
  err_pooled <- misclassification_error(X0_test, fit_pooled$beta_hat)
  err_naive  <- misclassification_error(X0_test, fit_naive)

  results[i, ] <- c(err_trans, err_pooled, err_naive)

  cat(sprintf("Fold %d: Trans=%.4f, Pooled=%.4f, Naive=%.4f\n",
              i, err_trans, err_pooled, err_naive))
}

colMeans(results, na.rm = TRUE)



### Experiment for all target countries (except UK) ###
#
# Settings for repeating the cross-validation experiment with each
# sufficiently large country in turn as the target.

library(dplyr)

top_p_items         <- 200   # items (graph nodes) kept for each target
min_invoices_target <- 50    # invoices needed for a country to be a target
min_invoices_aux    <- 50    # invoices needed for a country to be a source
k_folds             <- 5     # number of cross-validation folds
lambda_mult         <- 0.5   # multiplier on the lasso tuning parameter

# The United Kingdom is excluded from both roles.
exclude_as_aux    <- c("United Kingdom")
exclude_as_target <- c("United Kingdom")

cat("=== Country sample sizes ===\n")
print(country_counts)

# Targets keep the (largest-first) order of `country_counts`.
target_countries <- pull(
  filter(country_counts,
         n >= min_invoices_target,
         !(Country %in% exclude_as_target)),
  Country
)

cat("\n=== Target countries (n >= ", min_invoices_target, ") ===\n", sep = "")
print(target_countries)

# Helpers for run_trans_ising_for_country() ----------------------------------

# Names of the (at most) `top_p_items` columns of `X` with the largest column
# sums, in decreasing order of that sum.
.rtic_top_items <- function(X, top_p_items) {
  item_freq <- colSums(X)
  head(names(sort(item_freq, decreasing = TRUE)), top_p_items)
}

# Name of the per-country invoice-count column in `country_counts`. Accepts
# `n` (what dplyr::count() produces) or `n_invoices` (what
# summarise(n_invoices = n()) produces), so either table can be passed in.
.rtic_count_column <- function(country_counts) {
  candidates <- c("n", "n_invoices")
  found <- candidates[candidates %in% names(country_counts)]
  if (length(found) == 0) {
    stop("`country_counts` must have a count column named `n` or `n_invoices`.",
         call. = FALSE)
  }
  found[[1]]
}

# Rows of `X_full` belonging to `country`, with columns aligned to `items`;
# items that are not columns of `X_full` are filled with 0L.
.rtic_align_country <- function(country, invoice_meta, X_full, items) {
  M <- X_full[which(invoice_meta$Country == country), , drop = FALSE]
  aligned <- matrix(0L, nrow = nrow(M), ncol = length(items),
                    dimnames = list(NULL, items))
  common <- intersect(items, colnames(M))
  if (length(common) > 0) {
    aligned[, common] <- M[, common, drop = FALSE]
  }
  aligned
}

# Evaluate `expr` (fit one model and score it on the test fold). On error, or
# when the result is not a single number, log the problem and return NA_real_
# so one failing model does not discard the other models' errors for the fold.
.rtic_safe_error <- function(model, fold, expr) {
  tryCatch({
    value <- expr
    if (!is.numeric(value) || length(value) != 1L) {
      stop("expected a single numeric error, got an object of length ",
           length(value))
    }
    as.numeric(value)
  }, error = function(e) {
    cat(sprintf("  [Fold %d] %s ERROR: %s\n", fold, model, conditionMessage(e)))
    NA_real_
  })
}

#' Run the k-fold Trans-Ising comparison for a single target country.
#'
#' The target's invoices are restricted to its `top_p_items` most frequently
#' bought items. Every other non-excluded country with at least
#' `min_invoices_aux` invoices is aligned to the same items and used as an
#' auxiliary source. Trans-Ising (trans_loglasso), Pooled-Trans-Ising
#' (oracle_trans_loglasso) and Naive-LogLasso (naive_loglasso) are compared by
#' misclassification_error on held-out target invoices. Those four functions
#' are not defined in this script and must be in scope.
#'
#' Side effect: calls set.seed(2025) so fold assignment is reproducible per
#' target, which also resets the global RNG stream.
#'
#' @param target_country Country whose invoices form the target sample.
#' @param invoice_meta Data frame with a `Country` column, row-aligned with
#'   `X_full`.
#' @param X_full 0/1 invoice-by-item matrix with item names as column names.
#' @param country_counts Per-country invoice counts: a `Country` column plus a
#'   count column named `n` or `n_invoices`.
#' @param top_p_items Maximum number of items (nodes) kept for the target.
#' @param min_invoices_aux Minimum invoice count for an auxiliary country.
#' @param k_folds Number of CV folds (capped at the number of target invoices).
#' @param lambda_mult Multiplier in
#'   lambda_j = lambda_mult * sqrt(log(p) / (n_0 + max(n_A, 1))).
#' @param exclude_as_aux Countries never used as auxiliary sources.
#' @return NULL if the target is skipped (no invoices or fewer than 2 items);
#'   otherwise a list with `target`, `summary_df` (mean/SD error per model),
#'   `fold_results` (k_folds x 3 error matrix, NA where a model failed) and
#'   `aux_countries`.
run_trans_ising_for_country <- function(target_country,
                                        invoice_meta,
                                        X_full,
                                        country_counts,
                                        top_p_items      = 200,
                                        min_invoices_aux = 50,
                                        k_folds          = 5,
                                        lambda_mult      = 0.5,
                                        exclude_as_aux   = c("United Kingdom")) {

  cat("\n\n############################################################\n")
  cat(sprintf("###   STARTING EXPERIMENT FOR TARGET COUNTRY: %s   ###\n",
              target_country))
  cat("############################################################\n")

  idx_target <- which(invoice_meta$Country == target_country)
  n_target   <- length(idx_target)

  if (n_target == 0) {
    cat("  -> No invoices for this target. Skipping.\n")
    return(NULL)
  }

  X_target_full <- X_full[idx_target, , drop = FALSE]
  cat(sprintf("  [Target] '%s': %d invoices (rows)\n",
              target_country, n_target))

  top_items <- .rtic_top_items(X_target_full, top_p_items)
  p <- length(top_items)
  if (p < 2) {
    cat(sprintf("  -> Too few items selected (p = %d). Skipping.\n", p))
    return(NULL)
  }

  cat(sprintf("  -> Selected top %d items (p = %d nodes)\n",
              top_p_items, p))

  X0 <- X_target_full[, top_items, drop = FALSE]
  n0 <- nrow(X0)
  cat(sprintf("  -> Final X0: %d invoices × %d items\n", n0, p))

  # Candidate sources, largest first.
  count_col <- .rtic_count_column(country_counts)
  aux_countries <- country_counts %>%
    filter(Country != target_country,
           !Country %in% exclude_as_aux,
           .data[[count_col]] >= min_invoices_aux) %>%
    arrange(desc(.data[[count_col]])) %>%
    pull(Country)

  cat("\n  Auxiliary countries (candidates):\n")
  print(aux_countries)

  x_A_list <- lapply(aux_countries, .rtic_align_country,
                     invoice_meta = invoice_meta,
                     X_full       = X_full,
                     items        = top_items)
  names(x_A_list) <- aux_countries
  has_aux <- length(x_A_list) > 0

  if (has_aux) {
    aux_dims <- vapply(x_A_list,
                       function(M) paste0(nrow(M), "×", ncol(M)),
                       character(1))
    cat("\n  x_A_list dimensions (rows×cols):\n")
    print(aux_dims)
  } else {
    cat("  -> No auxiliary countries with enough invoices. Only Naive will be meaningful.\n")
  }

  # Fixed seed per target: folds are reproducible regardless of run order.
  set.seed(2025)
  k_folds <- min(k_folds, n0)
  fold_id <- sample(cut(seq_len(n0), breaks = k_folds, labels = FALSE))

  model_names <- c("Trans-Ising", "Pooled-Trans-Ising", "Naive-LogLasso")
  results <- matrix(NA_real_, nrow = k_folds, ncol = length(model_names),
                    dimnames = list(NULL, model_names))

  # The auxiliary sample size does not change between folds.
  n_A <- sum(vapply(x_A_list, nrow, integer(1)))

  for (i in seq_len(k_folds)) {
    cat(sprintf("\n----- Fold %d / %d -----\n", i, k_folds))

    X0_train <- X0[fold_id != i, , drop = FALSE]
    X0_test  <- X0[fold_id == i, , drop = FALSE]

    cat(sprintf("  [Fold %d] X0_train: %d×%d, X0_test: %d×%d\n",
                i, nrow(X0_train), ncol(X0_train),
                nrow(X0_test),  ncol(X0_test)))

    n_0 <- nrow(X0_train)
    lambda_j_val <- lambda_mult * sqrt(log(p) / (n_0 + max(n_A, 1)))

    # Each model is fitted and scored independently so a failure in one
    # does not wipe out the others' results for this fold.
    err_naive <- .rtic_safe_error("Naive-LogLasso", i, {
      cat("  Training Naive-LogLasso...\n")
      misclassification_error(X0_test, naive_loglasso(x_0 = X0_train))
    })

    if (has_aux) {
      err_trans <- .rtic_safe_error("Trans-Ising", i, {
        cat("  Training Trans-Ising...\n")
        fit_trans <- trans_loglasso(
          x_A_list = x_A_list,
          x_0      = X0_train,
          lambda_j = lambda_j_val
        )
        misclassification_error(X0_test, fit_trans$beta_hat)
      })

      err_pooled <- .rtic_safe_error("Pooled-Trans-Ising", i, {
        cat("  Training Pooled-Trans-Ising...\n")
        fit_pooled <- oracle_trans_loglasso(
          x_A_list = x_A_list,
          x_0      = X0_train,
          lambda_j = lambda_j_val
        )
        misclassification_error(X0_test, fit_pooled$beta_hat)
      })
    } else {
      err_trans  <- NA_real_
      err_pooled <- NA_real_
    }

    results[i, ] <- c(err_trans, err_pooled, err_naive)

    cat(sprintf("  Fold %d Errors: Trans=%.4f, Pooled=%.4f, Naive=%.4f\n",
                i, err_trans, err_pooled, err_naive))
  }

  avg <- colMeans(results, na.rm = TRUE)
  sdv <- apply(results, 2, sd, na.rm = TRUE)

  summary_df <- data.frame(
    Model        = names(avg),
    Mean_Error   = as.numeric(avg),
    SD_Error     = as.numeric(sdv),
    Target       = target_country,
    n_target_inv = n0,
    p            = p
  )

  cat("\n=== Summary for target country:", target_country, "===\n")
  print(summary_df)

  list(
    target        = target_country,
    summary_df    = summary_df,
    fold_results  = results,
    aux_countries = aux_countries
  )
}


# Run the experiment once per target country. Skipped targets return NULL and
# are simply not stored, so `all_results` holds one entry per completed
# target, named by country.
all_results <- list()

for (cty in target_countries) {
  res <- run_trans_ising_for_country(
    cty,
    invoice_meta     = invoice_meta,
    X_full           = X_full,
    country_counts   = country_counts,
    top_p_items      = top_p_items,
    min_invoices_aux = min_invoices_aux,
    k_folds          = k_folds,
    lambda_mult      = lambda_mult,
    exclude_as_aux   = exclude_as_aux
  )
  if (!is.null(res)) {
    all_results[[cty]] <- res
  }
}

# Stack every per-country summary table into one data frame.
summary_list_retail <- Filter(
  Negate(is.null),
  lapply(all_results, `[[`, "summary_df")
)

final_summary_retail <- bind_rows(summary_list_retail)

cat("\n\n================= FINAL SUMMARY OVER ALL TARGET COUNTRIES =================\n")
print(final_summary_retail)

write.csv(final_summary_retail, "online_retail_transising_country_results.csv", row.names = FALSE)




### Plots for online retail data ###

library(dplyr)
library(ggplot2)
library(rlang) 

#' Bar chart of each model's mean error relative to a baseline model.
#'
#' Within every group (by default every target country) the mean error of each
#' non-baseline model is divided by the baseline model's mean error. Bars
#' below the dashed line at 1 therefore beat the baseline.
#'
#' @param df Data frame with columns `Model`, `Mean_Error`, `SD_Error` and the
#'   grouping column. An optional `n_target_inv` column orders the groups on
#'   the x axis (largest first).
#' @param group_col Name (string) of the grouping column.
#' @param baseline_model Model whose mean error is the denominator.
#' @param title Plot title.
#' @param show_error_bars If TRUE, draw +/- one propagated SD around each bar.
#'   The SD uses first-order ratio error propagation. Bars whose propagated SD
#'   is not finite are then dropped. If FALSE (default), only a non-finite
#'   relative error removes a bar.
#' @return A ggplot object.
plot_relative_errors <- function(df,
                                 group_col = "Target",
                                 baseline_model = "Naive-LogLasso",
                                 title = "Comparison of Model Performance",
                                 show_error_bars = FALSE) {

  gsym <- sym(group_col)
  eps  <- 1e-12  # keeps the SD ratios finite when a mean error is 0

  rel_df <- df %>%
    group_by(!!gsym) %>%
    mutate(
      BaseMean = Mean_Error[Model == baseline_model][1],
      BaseSD   = SD_Error[Model == baseline_model][1]
    ) %>%
    ungroup() %>%
    filter(Model != baseline_model) %>%
    mutate(
      Relative_Error = Mean_Error / BaseMean,
      Propagated_SD  = sqrt(
        (SD_Error / pmax(Mean_Error, eps))^2 +
          (BaseSD / pmax(BaseMean, eps))^2
      ) * Relative_Error
    ) %>%
    # A missing/zero baseline or a model that never produced an error gives
    # NA/NaN/Inf here; such bars cannot be drawn.
    filter(is.finite(Relative_Error))

  # The SD only matters when it is drawn; e.g. with a single successful fold
  # it is NA, which previously hid an otherwise valid bar.
  if (isTRUE(show_error_bars)) {
    rel_df <- filter(rel_df, is.finite(Propagated_SD))
  }

  if ("n_target_inv" %in% names(df)) {
    order_levels <- df %>%
      distinct(!!gsym, n_target_inv) %>%
      arrange(desc(n_target_inv)) %>%
      pull(!!gsym) %>%
      unique()  # factor() rejects duplicated levels
    rel_df[[group_col]] <- factor(rel_df[[group_col]], levels = order_levels)
  }

  dodge <- position_dodge(width = 0.9)

  plt <- ggplot(rel_df, aes(x = !!gsym, y = Relative_Error, fill = Model)) +
    geom_col(position = dodge, alpha = 0.85)

  if (isTRUE(show_error_bars)) {
    plt <- plt +
      geom_errorbar(
        aes(ymin = pmax(Relative_Error - Propagated_SD, 0),
            ymax = Relative_Error + Propagated_SD),
        position = dodge,
        width    = 0.25
      )
  }

  plt +
    geom_hline(yintercept = 1, linetype = "dashed", color = "red", linewidth = 1) +

    labs(
      title = title,
      x = group_col,
      y = "Relative Misclassification Error",
      fill = "Model"
    ) +

    theme_minimal(base_family = "Times New Roman", base_size = 20) +
    theme(
      plot.title = element_text(hjust = 0.5, face = "bold", size = 30),

      axis.title = element_text(face = "bold", size = 27),
      axis.text.y = element_text(size = 23),
      axis.text.x = element_text(angle = 45, hjust = 1, face = "bold", size = 27),

      legend.position = "bottom",
      legend.title = element_blank(),
      legend.text = element_text(size = 30),
      legend.key.size = unit(1.0, "cm"),
      legend.margin = margin(t = 20)
    )
}

# Relative-error bar chart over all target countries, written as a large,
# high-resolution PNG (40 x 30 cm at 600 dpi).
p_retail <- plot_relative_errors(
  df        = final_summary_retail,
  group_col = "Target",
  title     = "Comparison of Model Performance (Retail)"
)

output_filename <- "realdata_retail.png"

ggsave(output_filename,
       plot   = p_retail,
       width  = 40,
       height = 30,
       units  = "cm",
       dpi    = 600)

cat(sprintf("Saved Retail plot to '%s'\n", output_filename))



### Informative source for online retail data ###
#
# For every target country (UK excluded), fit trans_loglasso once on all of
# the target's invoices. Record the auxiliary countries it reports in
# `informative_set`, together with the matching `loss_diffs`. trans_loglasso
# is not defined in this script and must already be in scope.

library(dplyr)

set.seed(2025)

top_p_items         <- 200
min_invoices_target <- 50
min_invoices_aux    <- 50
drop_uk_as_source   <- TRUE
lambda_mult_source  <- 0.5   # lambda_j = lambda_mult_source * sqrt(log(p) / (n_0 + max(n_A, 1)))

# NOTE: this rebinds the global `country_counts` (previously built with a
# count column named `n`) to a table whose count column is `n_invoices`.
country_counts <- invoice_meta %>%
  group_by(Country) %>%
  summarise(n_invoices = n(), .groups = "drop") %>%
  arrange(desc(n_invoices))

cat("=== Country sample sizes ===\n")
print(country_counts)

target_countries <- country_counts %>%
  filter(n_invoices >= min_invoices_target) %>%
  pull(Country)

target_countries <- setdiff(target_countries, "United Kingdom")

cat("\n=== Target countries (n_invoices >= ", min_invoices_target, ") ===\n", sep = "")
print(target_countries)

# One-row result recording that no informative source is reported for a target.
.no_informative_source_row <- function(country, n_target, p) {
  data.frame(
    TargetCountry      = country,
    InformativeSource  = NA_character_,
    Loss_Difference    = NA_real_,
    n_target_invoices  = n_target,
    p_nodes            = p
  )
}

informative_country_results <- list()

for (current_country in target_countries) {

  cat("\n\n############################################################\n")
  cat(sprintf("###  STARTING SOURCE DETECTION FOR TARGET COUNTRY: %s  ###\n",
              current_country))
  cat("############################################################\n")

  idx_target <- which(invoice_meta$Country == current_country)
  n_target   <- length(idx_target)

  if (n_target < min_invoices_target) {
    cat(sprintf("  -> Too few invoices for %s (n = %d). Skipping.\n",
                current_country, n_target))
    next
  }

  cat(sprintf("Target '%s': %d invoices\n", current_country, n_target))

  X_target_full <- X_full[idx_target, , drop = FALSE]
  rownames(X_target_full) <- invoice_meta$InvoiceNo[idx_target]

  # Graph nodes: the items bought on the most target invoices.
  item_freq_target <- colSums(X_target_full)
  sorted_items     <- sort(item_freq_target, decreasing = TRUE)
  top_items        <- head(names(sorted_items), top_p_items)
  p <- length(top_items)

  if (p < 2) {
    cat(sprintf("  -> Too few items selected for %s (p = %d). Skipping.\n",
                current_country, p))
    next
  }

  cat(sprintf("  -> Selected top %d items (p = %d nodes)\n",
              top_p_items, p))

  x_0_full <- X_target_full[, top_items, drop = FALSE]

  aux_country_counts <- country_counts %>%
    filter(Country != current_country,
           n_invoices >= min_invoices_aux)

  if (drop_uk_as_source) {
    aux_country_counts <- aux_country_counts %>%
      filter(Country != "United Kingdom")
  }

  aux_countries <- aux_country_counts$Country

  cat("\n  Auxiliary countries (candidates):\n")
  print(aux_countries)

  if (length(aux_countries) == 0) {
    cat("  -> No auxiliary countries with enough invoices. Skipping Trans-Ising.\n")
    informative_country_results[[current_country]] <-
      .no_informative_source_row(current_country, n_target, p)
    next
  }

  # Align each source country to the target's item columns (missing -> 0L).
  x_A_list_full <- lapply(aux_countries, function(cty) {
    idx <- which(invoice_meta$Country == cty)
    M   <- X_full[idx, , drop = FALSE]

    common <- intersect(top_items, colnames(M))
    M_sub  <- matrix(0L, nrow = nrow(M), ncol = length(top_items))
    colnames(M_sub) <- top_items

    if (length(common) > 0) {
      M_sub[, common] <- M[, common, drop = FALSE]
    }

    rownames(M_sub) <- invoice_meta$InvoiceNo[idx]
    M_sub
  })
  names(x_A_list_full) <- aux_countries

  cat("\n  x_A_list_full dimensions (rows×cols):\n")
  aux_dims <- vapply(x_A_list_full,
                     function(M) paste0(nrow(M), "×", ncol(M)),
                     character(1))
  print(aux_dims)

  # At least one auxiliary country is guaranteed at this point.
  n_0 <- nrow(x_0_full)
  n_A <- sum(vapply(x_A_list_full, nrow, integer(1)))

  lambda_j_val <- lambda_mult_source * sqrt(log(p) / (n_0 + max(n_A, 1)))

  cat(sprintf("\n### Running trans_loglasso for target '%s'... ###\n",
              current_country))

  # A failure for one target must not abort the remaining targets (and the
  # CSV export below); the failed target is logged and left out.
  source_detection_run <- tryCatch(
    trans_loglasso(
      x_A_list = x_A_list_full,
      x_0      = x_0_full,
      lambda_j = lambda_j_val
    ),
    error = function(e) {
      cat(sprintf("  -> trans_loglasso failed for %s: %s. Skipping.\n",
                  current_country, conditionMessage(e)))
      NULL
    }
  )
  if (is.null(source_detection_run)) {
    next
  }

  # NOTE(review): `informative_set` is used as positional indices into
  # `aux_countries` and `loss_diffs`; confirm trans_loglasso returns indices
  # in the order of `x_A_list`.
  informative_indices <- source_detection_run$informative_set

  if (length(informative_indices) > 0) {

    informative_countries <- aux_countries[informative_indices]
    loss_diffs            <- source_detection_run$loss_diffs[informative_indices]

    cat(sprintf("### Found %d informative sources for %s: %s ###\n",
                length(informative_countries),
                current_country,
                paste(informative_countries, collapse = ", ")))

    informative_country_results[[current_country]] <- data.frame(
      TargetCountry      = current_country,
      InformativeSource  = informative_countries,
      Loss_Difference    = loss_diffs,
      n_target_invoices  = n_target,
      p_nodes            = p
    )

  } else {
    cat(sprintf("### Found 0 informative sources for %s ###\n", current_country))
    informative_country_results[[current_country]] <-
      .no_informative_source_row(current_country, n_target, p)
  }
}

final_source_country_df <- do.call(rbind, informative_country_results)
rownames(final_source_country_df) <- NULL

cat("\n\n==================== INFORMATIVE SOURCE COUNTRIES ====================\n")
print(final_source_country_df)


write.csv(final_source_country_df,
          "informative_source_countries_online_retail.csv",
          row.names = FALSE)

