library(pcalg)
library(kpcalg)
library(tidyverse)
library(xgboost)

source('R/experiments/alzheimers/prepare_data.R')
source('R/experiments/alzheimers/tetrad_parser.R')
source('R/scm/graph_functions.R')
source('R/causal_discovery.R')

# Fix the RNG state up front so every run of this script is reproducible.
seed <- 123
set.seed(seed)

# Raw ADNI table; the path has to be completed to point at the ADNI data file.
merge <- read_csv("R/experiments/alzheimers/") # path to adni data

# Build the base data set and record its column names before the
# binarization helpers are applied.
data <- get_data_base(merge)
vertex_names <- colnames(data)
data <- data %>%
  binarize_apoe() %>% # As done by Shen et al.
  binarize_diagnosis() %>% # As done by Heskes et al.
  binarize_sex()

# Rows listed in the stored index file form the training set; all remaining
# rows are the observations that get explained.
train_idxs <- readRDS("R/experiments/alzheimers/train_idxs_0.8.rds")
n_data <- nrow(data)
test_idxs <- setdiff(seq_len(n_data), train_idxs)

# "DX" is the response column; every other column is a feature.
y_col_idx <- which(colnames(data) == "DX")
x_observational <- data[, -y_col_idx]
x_train <- x_observational[train_idxs, ]
x_to_explain <- x_observational[test_idxs, ]
y_train <- data[train_idxs, y_col_idx, drop = FALSE]
y_to_explain <- data[test_idxs, y_col_idx, drop = FALSE]

# Train xgboost model
# The number of boosting rounds is given under its full argument name
# (`nrounds`); the abbreviated `nround` only works through partial matching.
model_to_explain <- xgboost(
  data = as.matrix(x_train),
  label = as.matrix(y_train),
  nrounds = 100,
  verbose = FALSE
)

# Model predictions for the held-out observations that are explained below.
pred_to_explain <- predict(model_to_explain, as.matrix(x_to_explain))

# true dag
# Passed as `causal_dag` to the causal Shapley explanation further down.
true_dag <- get_gold_standard_dag(data)

# Calculate Shapley values
# Load the development copy of shapr from the local "shapr-new" directory.
devtools::load_all("shapr-new", quiet = TRUE)

# Settings shared by every explain() call in this script.
prediction_zero <- colMeans(y_train)[[1]] # mean of the training response
n_samples_expectation <- 1000
n_combinations_shapr <- NULL
n_batches <- 8

# Marginal Shapley values (features treated as independent).
explanation_marginal <- shapr::explain(
  model = model_to_explain,
  x_train = x_train,
  x_explain = x_to_explain,
  approach = "independence",
  prediction_zero = prediction_zero,
  n_samples = n_samples_expectation,
  n_combinations = n_combinations_shapr,
  n_batches = n_batches,
  seed = seed
)

# Conditional Shapley values (Gaussian approach).
explanation_conditional <- shapr::explain(
  model = model_to_explain,
  x_train = x_train,
  x_explain = x_to_explain,
  approach = "gaussian",
  prediction_zero = prediction_zero,
  n_samples = n_samples_expectation,
  n_combinations = n_combinations_shapr,
  n_batches = n_batches,
  seed = seed
)

# Causal Shapley values computed on the gold-standard DAG.
explanation_causal <- shapr::explain(
  model = model_to_explain,
  x_train = x_train,
  x_explain = x_to_explain,
  approach = "causal",
  causal_dag = true_dag,
  prediction_zero = prediction_zero,
  n_samples = n_samples_expectation,
  n_combinations = n_combinations_shapr,
  n_batches = n_batches,
  seed = seed
)

# Get these from tetrad_causal_discovery.R
mec_fges_apoe4142 <- readRDS("R/experiments/alzheimers/mec_fges_apoe4142.rds")
mec_pc_g_apoe4142 <- readRDS("R/experiments/alzheimers/mec_pc_g_apoe4142.rds")
causal_discovery_cpdags <- list(
  "FGES" = mec_fges_apoe4142,
  "PC" = mec_pc_g_apoe4142
)

# Explain the model once per discovered MEC with shapr's "mec" approach.
#
# @param method Value forwarded as `causal_approximation_method`
#   ("sample" or "iw" in this script).
# @return A list named like `causal_discovery_cpdags`. An entry is NULL when
#   the corresponding explain() call raised an error; the error is reported
#   as an immediate warning naming the MEC and the method, and the remaining
#   MECs are still processed.
explain_all_mecs <- function(method) {
  cpdag_names <- names(causal_discovery_cpdags)
  lapply(
    setNames(cpdag_names, cpdag_names),
    function(cpdag_name) {
      tryCatch(
        shapr::explain(
          model = model_to_explain,
          x_train = x_train,
          x_explain = x_to_explain,
          approach = "mec",
          causal_approximation_method = method,
          n_samples = n_samples_expectation,
          n_combinations = n_combinations_shapr,
          prediction_zero = prediction_zero,
          seed = seed,
          dags = causal_discovery_cpdags[[cpdag_name]],
          n_batches = n_batches
        ),
        error = function(e) {
          warning(
            sprintf(
              "shapr::explain() failed for %s MEC (causal_approximation_method = \"%s\"): %s",
              cpdag_name, method, conditionMessage(e)
            ),
            immediate. = TRUE,
            call. = FALSE
          )
          NULL
        }
      )
    }
  )
}

explanations_causal_discovery_sampling <- explain_all_mecs("sample")
explanations_causal_discovery_iw <- explain_all_mecs("iw")

# Save results
# All explanation objects are bundled into one list and written to a single
# RDS file, so the expensive computations above can be skipped on later runs.
explanations_path <- "R/experiments/alzheimers/explanations.rds"
explanations <- list(
  marginal = explanation_marginal,
  conditional = explanation_conditional,
  causal = explanation_causal,
  causal_discovery_sampling = explanations_causal_discovery_sampling,
  causal_discovery_iw = explanations_causal_discovery_iw
)
saveRDS(explanations, explanations_path)


# Load results
# Read the bundle back and restore the individual objects by name.
explanations <- readRDS(explanations_path)
explanation_marginal <- explanations[["marginal"]]
explanation_conditional <- explanations[["conditional"]]
explanation_causal <- explanations[["causal"]]
explanations_causal_discovery_sampling <- explanations[["causal_discovery_sampling"]]
explanations_causal_discovery_iw <- explanations[["causal_discovery_iw"]]

# Calculate statistics
# Row-wise normalized L2 error between two sets of Shapley values.
#
# @param shapley_true Reference values, one row per explained observation
#   (matrix or data frame of numerics).
# @param shapley_other Values to compare against the reference; expected to
#   have the same shape as `shapley_true`.
# @return Numeric vector with one entry per row: the Euclidean norm of the
#   row-wise difference divided by the Euclidean norm of the reference row.
compute_normalized_l2 <- function(shapley_true, shapley_other) {
  squared_diff <- (shapley_true - shapley_other)^2
  diff_norm <- sqrt(rowSums(squared_diff))
  true_norm <- sqrt(rowSums(shapley_true^2))
  diff_norm / true_norm
}

# Create tibble of results
# Columns: Method, idx_datapoint, error

# Shapley values from the gold-standard causal explanation are the reference
# every other method is compared against.
shapley_reference <- explanation_causal$shapley_values

# Fetch the Shapley values of one causal-discovery run by MEC name.
# Each MEC contains only one DAG, hence the `[[1]]`.
# Stops with an explicit message when the run is missing (its explain() call
# failed), instead of failing later inside the error computation.
get_mec_shapley <- function(mec_explanations, mec_name) {
  explanation <- mec_explanations[[mec_name]]
  if (is.null(explanation)) {
    stop("No explanation available for MEC '", mec_name, "'.", call. = FALSE)
  }
  explanation$shapley_values[[1]]
}

# Calculate L2 errors
l2_marginal <- compute_normalized_l2(
  shapley_reference, explanation_marginal$shapley_values
)
l2_conditional <- compute_normalized_l2(
  shapley_reference, explanation_conditional$shapley_values
)
l2_fges_sampling <- compute_normalized_l2(
  shapley_reference, get_mec_shapley(explanations_causal_discovery_sampling, "FGES")
)
l2_fges_iw <- compute_normalized_l2(
  shapley_reference, get_mec_shapley(explanations_causal_discovery_iw, "FGES")
)
l2_pc_sampling <- compute_normalized_l2(
  shapley_reference, get_mec_shapley(explanations_causal_discovery_sampling, "PC")
)
l2_pc_iw <- compute_normalized_l2(
  shapley_reference, get_mec_shapley(explanations_causal_discovery_iw, "PC")
)

# One entry per method, in the order the rows appear in the results table.
l2_errors <- list(
  "Marginal" = l2_marginal,
  "Conditional" = l2_conditional,
  "FGES MEC (sampling)" = l2_fges_sampling,
  "FGES MEC (IW)" = l2_fges_iw,
  "PC MEC (sampling)" = l2_pc_sampling,
  "PC MEC (IW)" = l2_pc_iw
)

# Long format: one row per (method, explained observation).
tibble_results <- bind_rows(lapply(names(l2_errors), function(method) {
  tibble(
    Method = method,
    idx_datapoint = seq_along(l2_errors[[method]]),
    error = l2_errors[[method]]
  )
}))

# Write and re-read as comma-separated values so the file content matches its
# ".csv" extension and reading does not depend on delimiter guessing.
results_path <- "R/experiments/alzheimers/results.csv"
write_csv(tibble_results, results_path)

tibble_results <- read_csv(results_path)

# Get mean and standard error over datapoints
summary_results <- tibble_results %>%
  group_by(Method) %>%
  summarise(
    mean_error = mean(error),
    stderr_error = sd(error) / sqrt(n()),
    .groups = "drop"
  )

# Plot
# Put 'Marginal' and 'Conditional' at the start of the model list, followed
# by all remaining methods.
unique_models <- setdiff(unique(summary_results$Method), c("Marginal", "Conditional"))
ordered_models <- c("Marginal", "Conditional", unique_models)
desired_order <- ordered_models

# Add order column: Method as a factor whose levels follow the display order.
summary_results$Order <- factor(summary_results$Method, levels = desired_order)

# Method is mapped to x directly; the display order comes solely from the
# `limits` of the discrete x scale. (Calling reorder() with the factor column
# as the sort key would average a factor, which warns and yields NA.)
g <- ggplot(summary_results, aes(x = Method, color = Method)) +
  geom_point(aes(y = mean_error), size = 3) +
  geom_errorbar(
    aes(
      ymin = mean_error - stderr_error,
      ymax = mean_error + stderr_error
    ),
    width = 0
  ) +
  scale_x_discrete(limits = desired_order) + # This will ensure the order on the x-axis
  labs(title = "ADNI Dataset", x = "Method", y = expression(bar(NL2E))) +
  theme_minimal() +
  theme(
    axis.text.x = element_text(angle = 45, hjust = 1, size = 12),
    axis.title.x = element_text(size = 14),
    axis.title.y = element_text(size = 14),
    plot.title = element_text(hjust = 0.5),
    legend.position = "none"
  )
ggsave("R/experiments/alzheimers/plot_alzheimers.pdf", g, width = 9, height = 3, dpi = 300)
ggsave("R/experiments/alzheimers/plot_alzheimers_tall.pdf", g, width = 7, height = 3, dpi = 300)

# Sina plot
source('R/experiments/alzheimers/sina_plot.R')

# Copy an explanation object and attach the two fields this script sets
# before calling sina_plot(): the explained feature values (`x_test`) and the
# Shapley values (`dt`).
# Column reordering of `dt` is currently disabled; the intended order was
# select(none, SEX, APOE41, APOE42, AGE, EDU, ABETA, FDG, PTAU).
prepare_sina_input <- function(explanation) {
  explanation$x_test <- x_to_explain
  explanation$dt <- explanation$shapley_values
  explanation
}

# Marginal Shapley values
explanation_marginal_sina <- prepare_sina_input(explanation_marginal)
sina_plot_marginal <- sina_plot(explanation_marginal_sina, x_bound = 0.5) +
  labs(y = "Marginal Shapley value")
ggplot2::ggsave(
  "R/experiments/alzheimers/sina_plot_marginal.pdf",
  sina_plot_marginal,
  width = 7, height = 6, dpi = 300
)

# Conditional Shapley values
explanation_conditional_sina <- prepare_sina_input(explanation_conditional)
sina_plot_conditional <- sina_plot(explanation_conditional_sina, x_bound = 0.5) +
  labs(y = "Conditional Shapley value")
ggplot2::ggsave(
  "R/experiments/alzheimers/sina_plot_conditional.pdf",
  sina_plot_conditional,
  width = 7, height = 6, dpi = 300
)

# Causal Shapley values on the gold-standard DAG
explanation_ground_truth_sina <- prepare_sina_input(explanation_causal)
sina_plot_ground_truth <- sina_plot(explanation_ground_truth_sina, x_bound = 0.5) +
  labs(y = "Gold standard causal Shapley value")
ggplot2::ggsave(
  "R/experiments/alzheimers/sina_plot_ground_truth.pdf",
  sina_plot_ground_truth,
  width = 7, height = 6, dpi = 300
)


