# analyze_and_plot.R
# Unified analysis script: data processing, statistical analysis, plotting, LaTeX table generation

library(ggplot2)
library(dplyr)
library(tidyr)
library(patchwork)
library(xtable)

# ==============================================================================
# Command Line Arguments
# ==============================================================================

# The first trailing argument, if given, sets n_small, the `n` value embedded
# in every result/figure/table file name below; it defaults to 100. It must be
# a positive whole number because those names are built with "%d".
args <- commandArgs(trailingOnly = TRUE)
n_small <- if (length(args) >= 1) suppressWarnings(as.numeric(args[1])) else 100
if (is.na(n_small) || n_small <= 0 || n_small != round(n_small)) {
  stop("n_small must be a positive whole number, got: ", args[1], call. = FALSE)
}

# ==============================================================================
# Configuration
# ==============================================================================

# True coefficient vector (beta_1, ..., beta_6) the estimates are compared to.
coef_true <- c(0.15, -0.15, rep(0.3, 4))

# Numbers of batches K for which result files are read.
k_values <- c(6, 20, 50, 200)

# Column layout of each method's result CSV:
#   coef_cols   - six consecutive coefficient-estimate columns
#   se_cols     - the six standard-error columns that immediately follow them
#   n_basis_col - column with the number of hazard basis terms (NULL if none)
#   n_used_col  - column with the number of batches used (for meta/online_bc;
#                 NULL if none)
method_config <- local({
  column_layout <- function(first_coef, n_basis_col = NULL, n_used_col = NULL) {
    coef_cols <- first_coef + 0:5
    list(
      coef_cols = coef_cols,
      se_cols = coef_cols + 6L,
      n_basis_col = n_basis_col,
      n_used_col = n_used_col
    )
  }

  list(
    oracle = column_layout(3L),
    colsa = column_layout(4L, n_basis_col = 3),
    meta = column_layout(4L, n_used_col = 3),
    online_bc = column_layout(5L, n_basis_col = 4, n_used_col = 3),
    sgd = column_layout(3L),
    sgd_offline = column_layout(3L)
  )
})

# Visual encoding per method, keyed by the method names in method_config.
# Colours are taken from the colorblind-friendly Okabe-Ito palette.
method_colors <- c(
  oracle      = "#000000",
  colsa       = "#0072B2",
  meta        = "#56B4E9",
  online_bc   = "#D55E00",
  sgd         = "#CC79A7",
  sgd_offline = "#009E73"
)

# Point shapes (R pch codes).
method_shapes <- c(
  oracle      = 16,
  colsa       = 17,
  meta        = 15,
  online_bc   = 18,
  sgd         = 8,
  sgd_offline = 4
)

# Line types.
method_linetypes <- c(
  oracle      = "solid",
  colsa       = "dashed",
  meta        = "dotted",
  online_bc   = "dotdash",
  sgd         = "longdash",
  sgd_offline = "twodash"
)

# Display names used in legends.
method_labels <- c(
  oracle      = "Oracle",
  colsa       = "COLSA",
  meta        = "Meta",
  online_bc   = "Online",
  sgd         = "SGD",
  sgd_offline = "SGD (Offline)"
)

# ==============================================================================
# Helper Functions
# ==============================================================================

#' Load and preprocess the simulation results for one method and one K
#'
#' Reads `<results_dir>/sim_K=<K>_n=<n>_method=<method>.csv`, drops rows in
#' which any coefficient or standard-error column is missing, and splits the
#' remaining columns according to `config`.
#'
#' @param K Number of batches.
#' @param method Method name (a name of `method_config`).
#' @param n The `n` value embedded in the file name; defaults to `n_small`.
#' @param config Column layout for `method`; defaults to `method_config[[method]]`.
#' @param results_dir Directory holding the result CSVs.
#' @return `list(idx, elapsed, coef, se, n_basis, n_used)`: `idx` and `elapsed`
#'   are columns 1 and 2; `coef` and `se` are matrices with one row per retained
#'   row of the file; `n_basis` / `n_used` are NULL when `config` defines no
#'   such column. Returns NULL when the file is missing or has no complete rows.
load_method_data <- function(K, method, n = n_small,
                             config = method_config[[method]],
                             results_dir = "results") {
  file_path <- file.path(
    results_dir,
    sprintf("sim_K=%d_n=%d_method=%s.csv", K, n, method)
  )
  if (!file.exists(file_path)) {
    cat("File not found:", file_path, "\n")
    return(NULL)
  }
  if (is.null(config)) {
    stop("Unknown method '", method, "': no column layout in method_config.",
         call. = FALSE)
  }

  df <- read.csv(file_path)
  n_total <- nrow(df)

  # Keep only rows where every estimate and standard error is available.
  keep <- complete.cases(df[, c(config$coef_cols, config$se_cols)])
  df <- df[keep, , drop = FALSE]

  if (nrow(df) == 0) {
    cat("No valid data for", method, "K =", K, "\n")
    return(NULL)
  }
  if (nrow(df) < n_total) {
    cat("Dropped", n_total - nrow(df), "of", n_total,
        "rows with missing values for", method, "K =", K, "\n")
  }

  list(
    idx = df[[1]],
    elapsed = df[[2]],
    coef = as.matrix(df[, config$coef_cols, drop = FALSE]),
    se = as.matrix(df[, config$se_cols, drop = FALSE]),
    n_basis = if (!is.null(config$n_basis_col)) df[[config$n_basis_col]] else NULL,
    n_used = if (!is.null(config$n_used_col)) df[[config$n_used_col]] else NULL
  )
}

#' Compute per-coefficient simulation metrics
#'
#' @param coef Coefficient matrix (n_rep x n_coef), one row per replication.
#' @param se Standard error matrix with the same shape as `coef`.
#' @param coef_true True coefficient vector (length n_coef).
#' @param baseline Baseline coefficient matrix, row-aligned with `coef`, used
#'   for the relative bias (ARP). NA entries (e.g. replications with no
#'   baseline counterpart) are ignored when averaging ARP.
#' @param scale If TRUE, bias/ESE/ASE/CP/ARP are multiplied by 100 and MSE by
#'   1e4; if FALSE the raw values are returned.
#' @param z Normal quantile for the Wald interval behind the coverage
#'   probability (the default 1.96 gives a nominal 95% interval).
#' @return data.frame with one row per coefficient and columns
#'   arp, cp, mse, ase, ese, bias.
compute_stats <- function(coef, se, coef_true, baseline, scale = TRUE, z = 1.96) {
  bias <- colMeans(coef) - coef_true
  ese <- apply(coef, 2, sd)
  # Bias-variance decomposition using the (n - 1)-denominator empirical SD.
  mse <- bias^2 + ese^2
  ase <- colMeans(se)

  lower <- coef - z * se
  upper <- coef + z * se
  cp <- colMeans(sweep(lower, 2, coef_true, "<=") &
    sweep(upper, 2, coef_true, ">="))
  arp <- colMeans(abs((coef - baseline) / baseline), na.rm = TRUE)

  scale_factor <- c(bias = 100, ese = 100, ase = 100, mse = 1e4, cp = 100, arp = 100)
  # Keep the names when disabling scaling: indexing an unnamed vector by name
  # yields NA, which would turn every metric into NA.
  if (!scale) scale_factor[] <- 1

  data.frame(
    arp = arp * scale_factor[["arp"]],
    cp = cp * scale_factor[["cp"]],
    mse = mse * scale_factor[["mse"]],
    ase = ase * scale_factor[["ase"]],
    ese = ese * scale_factor[["ese"]],
    bias = bias * scale_factor[["bias"]]
  )
}

#' Format numbers as fixed-point strings
#' @param x Numeric vector.
#' @param digits Number of digits after the decimal point.
#' @return Character vector, e.g. `fmt(3.14159, 2)` gives "3.14".
fmt <- function(x, digits = 1) {
  template <- paste0("%.", digits, "f")
  sprintf(template, x)
}

# ==============================================================================
# Data Loading
# ==============================================================================

cat("===== Loading Data =====\n")

# Every successfully loaded run, keyed "<method>_K<K>" (e.g. "colsa_K20").
# Runs whose file is missing or has no complete rows are simply absent.
all_data <- list()

for (K in k_values) {
  for (method in names(method_config)) {
    loaded <- load_method_data(K, method)
    if (is.null(loaded)) next
    all_data[[paste0(method, "_K", K)]] <- loaded
  }
}

# ==============================================================================
# 1. Plot Data Preparation
# ==============================================================================

cat("\n===== Preparing Plot Data =====\n")

#' Summarise one loaded run into a single row of plotting metrics
#'
#' Abias averages |estimate - truth| over replications and then over
#' coefficients (a mean absolute error), which is not the same as
#' |mean(estimate) - truth|. NOTE(review): confirm this is the intended
#' definition of "absolute bias" for panel (a).
#' RMSE is sqrt(sum_j (bias_j^2 + ESE_j^2)); AvgCP is the coverage of the
#' 1.96-SE Wald interval averaged over coefficients.
#'
#' @param run One element of `all_data`.
#' @param K,method Identify the run; copied into the output row.
#' @return One-row data.frame: K, Method, Abias, RMSE, AvgCP, n_basis, n_used.
plot_row <- function(run, K, method) {
  deviations <- sweep(run$coef, 2, coef_true, "-")
  sds <- apply(run$coef, 2, sd)
  lo <- run$coef - 1.96 * run$se
  hi <- run$coef + 1.96 * run$se
  covered <- sweep(lo, 2, coef_true, "<=") & sweep(hi, 2, coef_true, ">=")

  data.frame(
    K = K, Method = method,
    Abias = mean(colMeans(abs(deviations))),
    RMSE = sqrt(sum(colMeans(deviations)^2 + sds^2)),
    AvgCP = mean(colMeans(covered)),
    n_basis = if (is.null(run$n_basis)) NA else mean(run$n_basis),
    n_used = if (is.null(run$n_used)) NA else mean(run$n_used)
  )
}

plot_results <- list()

for (K in k_values) {
  for (method in names(method_config)) {
    run <- all_data[[paste0(method, "_K", K)]]
    if (is.null(run)) next
    plot_results[[length(plot_results) + 1]] <- plot_row(run, K, method)
  }
}

plot_df <- bind_rows(plot_results)
plot_df$Method <- factor(plot_df$Method, levels = names(method_config))

cat("\n=== Plot Data ===\n")
print(plot_df)

# ==============================================================================
# 2. Statistical Analysis
# ==============================================================================

cat("\n===== Computing Statistics =====\n")

# Per-run metrics (stats_results) and mean elapsed time (time_results), both
# keyed "<method>_K<K>". The metrics need the oracle run for the same K as
# the ARP baseline; the timing does not, so it is recorded even when the
# oracle run is missing.
stats_results <- list()
time_results <- list()

for (K in k_values) {
  oracle_data <- all_data[[paste0("oracle_K", K)]]
  if (is.null(oracle_data)) {
    cat("Oracle data not found for K =", K, "\n")
  }

  for (method in names(method_config)) {
    key <- paste0(method, "_K", K)
    data <- all_data[[key]]
    if (is.null(data)) next

    time_results[[key]] <- mean(data$elapsed)

    if (is.null(oracle_data)) next

    # Align the oracle estimates with this method's rows by `idx`; rows with
    # no oracle counterpart get an all-NA baseline row.
    matched <- match(data$idx, oracle_data$idx)
    n_unmatched <- sum(is.na(matched))
    if (n_unmatched > 0) {
      cat("Warning:", n_unmatched, "replication(s) of", method, "K =", K,
          "have no oracle counterpart\n")
    }
    baseline <- oracle_data$coef[matched, , drop = FALSE]

    stats_results[[key]] <- compute_stats(
      data$coef, data$se, coef_true, baseline, scale = TRUE
    )
  }
}

# ==============================================================================
# 3. Plotting
# ==============================================================================

cat("\n===== Generating Plots =====\n")

#' ggplot2 theme shared by every figure
#'
#' `theme_bw()` at `base_size`, with a bold centred title, black axis text,
#' an untitled legend on a white background, no minor grid lines, light grey
#' major grid lines and a thin black panel border.
#' @param base_size Base font size.
#' @return A ggplot2 theme object.
theme_paper <- function(base_size = 11) {
  text_small <- base_size - 1

  overrides <- theme(
    plot.title = element_text(
      size = base_size + 1, face = "bold", hjust = 0.5,
      margin = margin(b = 8)
    ),
    axis.title.x = element_text(size = base_size, margin = margin(t = 8)),
    axis.title.y = element_text(size = base_size, margin = margin(r = 8)),
    axis.text = element_text(size = text_small, color = "black"),
    legend.title = element_blank(),
    legend.text = element_text(size = text_small),
    legend.key.width = unit(1.5, "lines"),
    legend.key.height = unit(0.9, "lines"),
    legend.spacing.x = unit(0.3, "lines"),
    legend.background = element_rect(fill = "white", color = NA),
    legend.margin = margin(t = 2, b = 2, l = 4, r = 4),
    panel.grid.minor = element_blank(),
    panel.grid.major = element_line(color = "grey85", linewidth = 0.4),
    panel.border = element_rect(color = "black", linewidth = 0.6),
    plot.margin = margin(t = 8, r = 12, b = 8, l = 8)
  )

  theme_bw(base_size = base_size) + overrides
}

#' Draw one "metric vs K" panel
#'
#' Shared layout for panels (a)-(d): one line + point series per method using
#' the given colour, shape and linetype encodings, x breaks at `k_values`, and
#' the legend hidden (the saving steps add a legend where one is needed).
#'
#' @param df Data frame with columns `K`, `Method` and the column named `y`.
#' @param y Name (string) of the column drawn on the y axis.
#' @param title Panel title.
#' @param y_label Y-axis label.
#' @param colors,shapes,linetypes,labels Named vectors keyed by method.
#' @return A ggplot object.
plot_metric_vs_k <- function(df, y, title, y_label,
                             colors = method_colors,
                             shapes = method_shapes,
                             linetypes = method_linetypes,
                             labels = method_labels) {
  ggplot(df, aes(
    x = K, y = .data[[y]], color = Method,
    shape = Method, linetype = Method
  )) +
    geom_line(linewidth = 0.8) +
    geom_point(size = 2.5) +
    scale_color_manual(values = colors, labels = labels) +
    scale_shape_manual(values = shapes, labels = labels) +
    scale_linetype_manual(values = linetypes, labels = labels) +
    scale_x_continuous(breaks = k_values, labels = k_values) +
    labs(title = title, x = expression(italic(K)), y = y_label) +
    theme_paper() +
    theme(legend.position = "none")
}

# Plot 1: Abias vs K
p1 <- plot_metric_vs_k(
  plot_df, "Abias", "(a) Average of Absolute Bias", "Absolute Bias"
)

# Plot 2: RMSE vs K
p2 <- plot_metric_vs_k(
  plot_df, "RMSE", "(b) Root Mean Squared Error", "RMSE"
)

# Plot 3: Coverage Probability vs K, with the nominal 95% level as reference
p3 <- plot_metric_vs_k(
  plot_df, "AvgCP", "(c) Coverage Probability", "Coverage Probability"
) +
  geom_hline(yintercept = 0.95, linetype = "dashed", color = "grey50", linewidth = 0.5)

# Plot 4: Number of Parameters vs K (COLSA and Online)
# Total parameters = p (number of coefficients) + J (hazard intervals)
df_basis <- plot_df %>%
  filter(Method %in% c("colsa", "online_bc"), !is.na(n_basis)) %>%
  mutate(n_params = length(coef_true) + n_basis)

basis_colors <- c(colsa = "#0072B2", online_bc = "#D55E00")
basis_shapes <- c(colsa = 17, online_bc = 18)
basis_linetypes <- c(colsa = "dashed", online_bc = "dotdash")
basis_labels <- c(colsa = "COLSA", online_bc = "Online")

p4 <- plot_metric_vs_k(
  df_basis, "n_params", "(d) Number of Parameters", "Parameters",
  colors = basis_colors, shapes = basis_shapes,
  linetypes = basis_linetypes, labels = basis_labels
)

# Save individual plots.
# Every standalone figure, including panel (d), gets a one-row legend at the
# bottom so the methods can be told apart without the combined figure.
save_single_panel <- function(plot, suffix) {
  invisible(ggsave(
    sprintf("results/fig_n=%d_%s.pdf", n_small, suffix),
    plot +
      theme(legend.position = "bottom", legend.direction = "horizontal") +
      guides(color = guide_legend(nrow = 1)),
    width = 4.5, height = 4, device = cairo_pdf
  ))
}

save_single_panel(p1, "abias")
save_single_panel(p2, "rmse")
save_single_panel(p3, "avgcp")
save_single_panel(p4, "nbasis")

# Combined 2x2 plot

#' Extract the legend (guide box) grob from a ggplot
#'
#' Matches gtable layout entries whose name starts with "guide-box". ggplot2
#' before 3.5.0 names the box "guide-box"; 3.5.0 and later use
#' position-specific names such as "guide-box-bottom" and fill unused
#' positions with empty zeroGrobs, which are skipped.
#' @param plot A ggplot object.
#' @return The legend grob, or NULL if the plot has no legend.
extract_legend <- function(plot) {
  gt <- ggplot_gtable(ggplot_build(plot))
  is_box <- startsWith(gt$layout$name, "guide-box")
  is_empty <- vapply(gt$grobs, inherits, logical(1), what = "zeroGrob")
  hits <- which(is_box & !is_empty)
  if (length(hits) > 0) gt$grobs[[hits[1]]] else NULL
}

# Throw-away plot whose only purpose is to provide one legend for all methods.
legend_plot <- ggplot(plot_df, aes(
  x = K, y = Abias, color = Method,
  shape = Method, linetype = Method
)) +
  geom_line(linewidth = 0.8) +
  geom_point(size = 2.5) +
  scale_color_manual(values = method_colors, labels = method_labels) +
  scale_shape_manual(values = method_shapes, labels = method_labels) +
  scale_linetype_manual(values = method_linetypes, labels = method_labels) +
  theme_paper() +
  theme(legend.position = "bottom", legend.direction = "horizontal") +
  guides(
    color = guide_legend(
      nrow = 1, byrow = TRUE,
      override.aes = list(size = 2.5, linewidth = 0.8)
    ),
    shape = guide_legend(nrow = 1, byrow = TRUE),
    linetype = guide_legend(nrow = 1, byrow = TRUE)
  )

legend_grob <- extract_legend(legend_plot)
p_grid <- (p1 | p2) / (p3 | p4)
p_combined <- if (is.null(legend_grob)) {
  cat("No legend found for the combined figure; saving it without a legend.\n")
  p_grid
} else {
  p_grid / wrap_elements(legend_grob) + plot_layout(heights = c(10, 10, 1))
}

ggsave(sprintf("results/fig_n=%d_combined.pdf", n_small), p_combined,
  width = 8, height = 7, device = cairo_pdf
)

cat("Plots saved to results/\n")

# ==============================================================================
# 4. LaTeX Table Generation
# ==============================================================================

cat("\n===== Generating LaTeX Tables =====\n")

methods_for_table <- c("oracle", "colsa", "meta", "online_bc", "sgd", "sgd_offline")
# Table row labels are the same labels used in the figure legends.
method_labels_table <- method_labels

# Coefficients reported in Table 2, and the metrics shown for each K.
table_coefs <- c(1L, 6L)
metric_names <- c("AARB", "CP", "MSE", "ASE", "ESE")
# Headers, the tabular column spec and the \cline range are derived from
# k_values, so the declared column count always matches the
# 1 + cells-per-K * length(k_values) cells written on each row.
n_metric_cols <- length(metric_names) * length(k_values)
k_headers <- paste0("K = ", k_values)

#' Build one LaTeX table row for `method`
#'
#' @param method Method name (key into `method_labels_table`).
#' @param cells_fn Function of a result key "<method>_K<K>" returning the
#'   formatted cells for that K, or NULL when there are no results.
#' @param n_cells Number of cells per K; used for "-" placeholders.
#' @param indent Leading whitespace for the row.
#' @return One string: italic method label, the cells, and the row terminator.
latex_row <- function(method, cells_fn, n_cells, indent) {
  per_k <- vapply(k_values, function(K) {
    cells <- cells_fn(paste0(method, "_K", K))
    if (is.null(cells)) cells <- rep("-", n_cells)
    paste0(" & ", cells, collapse = "")
  }, character(1))
  paste0(
    indent, "\\textit{", method_labels_table[[method]], "}",
    paste(per_k, collapse = ""), " \\\\"
  )
}

# Table 1 cell: mean elapsed seconds, or NULL when the run is missing.
time_cells <- function(key) {
  secs <- time_results[[key]]
  if (is.null(secs)) NULL else fmt(secs, 3)
}

# Table 2 cells for coefficient j: AARB, CP, MSE, ASE, ESE (NULL if missing).
metric_cells <- function(j) {
  force(j)
  function(key) {
    res <- stats_results[[key]]
    if (is.null(res)) {
      return(NULL)
    }
    c(
      fmt(res$arp[j], 1), fmt(res$cp[j], 1), fmt(res$mse[j], 2),
      fmt(res$ase[j], 2), fmt(res$ese[j], 2)
    )
  }
}

metric_header <- paste0(
  "            \\textit{Methods} & ",
  paste(rep(sprintf("\\textit{%s}", metric_names), length(k_values)), collapse = " & "),
  " \\\\"
)

# Sub-table for coefficient j: caption row, metric header and one row per method.
coef_section <- function(j) {
  c(
    sprintf("            & \\multicolumn{%d}{c}{$\\beta_%d$} \\\\", n_metric_cols, j),
    metric_header,
    "            \\hline",
    vapply(methods_for_table, latex_row, character(1),
      cells_fn = metric_cells(j), n_cells = length(metric_names),
      indent = "            ", USE.NAMES = FALSE
    )
  )
}

# ----- Table 1: Computation Time -----
time_table <- c(
  "% ===== Table 1: Computation Time =====",
  "\\begin{table}[htbp]\\footnotesize",
  "    \\centering",
  "    \\caption{Average computation time (seconds) for varying numbers of sites.}\\label{tab:time}",
  sprintf("    \\begin{tabular}{l%s}", strrep("c", length(k_values))),
  "        \\hline",
  paste0("        \\textit{Methods} & ", paste(k_headers, collapse = " & "), " \\\\"),
  "        \\hline",
  vapply(methods_for_table, latex_row, character(1),
    cells_fn = time_cells, n_cells = 1, indent = "        ",
    USE.NAMES = FALSE
  ),
  "        \\hline",
  "    \\end{tabular}",
  "\\end{table}",
  ""
)

# ----- Table 2: Performance Metrics -----
metric_sections <- unlist(lapply(seq_along(table_coefs), function(i) {
  c(if (i > 1) "            \\hline", coef_section(table_coefs[[i]]))
}))

metric_table <- c(
  "% ===== Table 2: Performance Metrics =====",
  "\\begin{table*}[]\\footnotesize",
  "    \\centering",
  sprintf(
    "    \\caption{Simulation results for %s for varying numbers of sites.}\\label{tab:simulation}",
    paste(sprintf("$\\bbeta_%d$", table_coefs), collapse = " and ")
  ),
  "    \\scalebox{0.85}{",
  sprintf("        \\begin{tabular}{l%s}", strrep("c", n_metric_cols)),
  "            \\hline",
  paste0(
    "            & ",
    paste(sprintf("\\multicolumn{%d}{c}{%s}", length(metric_names), k_headers), collapse = " & "),
    " \\\\"
  ),
  sprintf("            \\cline{2-%d}", n_metric_cols + 1L),
  metric_sections,
  "            \\hline",
  "        \\end{tabular}}",
  "    \\begin{tablenotes}{",
  "            \\item AARB, average absolute relative bias with respect to the oracle estimate (\\%); CP, coverage probability (\\%); MSE, mean squared error ($\\times 10^{-4}$); ASE, mean estimated standard error ($\\times 10^{-2}$); ESE, empirical standard error ($\\times 10^{-2}$).",
  "        }\\end{tablenotes}",
  "\\end{table*}"
)

# Write both tables in one call. Unlike sink(), an error while the tables
# are being assembled cannot leave console output redirected to the file.
tex_path <- sprintf("results/simulation_n=%d_tables.tex", n_small)
writeLines(c(time_table, metric_table), tex_path)

cat(sprintf("LaTeX tables saved to %s\n", tex_path))

# ==============================================================================
# 5. Console Summary Output
# ==============================================================================

cat("\n===== Summary Table =====\n")

# Console view of plot_df ordered by K then method level. Metrics are shown
# with 3 decimals; n_basis / n_used with 1 decimal, or "-" when not reported.
summary_table <- plot_df %>%
  arrange(K, Method) %>%
  mutate(
    across(c(Abias, RMSE, AvgCP), function(v) sprintf("%.3f", v)),
    across(c(n_basis, n_used), function(v) ifelse(is.na(v), "-", sprintf("%.1f", v)))
  )
print(summary_table)

cat("\n===== Analysis Complete =====\n")
