## Script adapted from the original script by Heskes et al. (2020) available at
## https://proceedings.neurips.cc/paper/2020/hash/32e54441e6382a7fbacbbbaf3c450059-Abstract.html

# NOTE(review): clearing the global environment keeps runs reproducible but
# also wipes the caller's workspace when sourced interactively; prefer running
# this script in a fresh R session (e.g. via Rscript).
rm(list = ls())
# Set to TRUE to write the optional sinyear bar plot to disk instead of
# printing it. It must be defined because that section reads it; the other
# figures in this script are always written with ggsave().
save_plots <- FALSE

# 0 - Load Packages and Source Files --------------------------------------

library(tidyverse)
library(data.table)
library(xgboost)
library(ggpubr)

# For sina plotting capabilities
# source("R/sina_plot.R")

# if (save_plots) {
#   dir.create("figures")
# }

# 1 - Prepare and Plot Data -----------------------------------------------

# Data source: https://archive.ics.uci.edu/ml/datasets/bike+sharing+dataset
 
bike <- read.csv("R/experiments/bikerental/day.csv")

# Elapsed time (in days) since the first record. difftime() converts the date
# strings to POSIXct in the session time zone, so DST transitions are
# reflected in the result.
bike$trend <- as.numeric(
  difftime(bike$dteday, bike$dteday[1], units = "days")
)
# bike$trend <- as.integer(difftime(bike$dteday, min(as.Date(bike$dteday)))+1)/24

# Position within a 365-day cycle, encoded as a point on the unit circle.
bike[c("cosyear", "sinyear")] <- local({
  half_turns <- bike$trend / 365 * 2
  list(cospi(half_turns), sinpi(half_turns))
})

# Undo the min-max scaling of the distributed data set (constants taken from
# the data set information at the link above).
bike[c("temp", "atemp", "windspeed", "hum")] <- list(
  bike$temp * (39 - (-8)) + (-8),
  bike$atemp * (50 - (-16)) + (-16),
  67 * bike$windspeed,
  100 * bike$hum
)


# Model inputs (engineered time features + weather) and target column.
x_var <- c("trend", "cosyear", "sinyear", "temp", "atemp", "windspeed", "hum")
y_var <- "cnt"

# NOTE: Encountered RNG reproducibility issues across different systems,
# so we saved the training-test split.
# set.seed(2013)
# train_index <- caret::createDataPartition(bike$cnt, p = .8, list = FALSE, times = 1)
train_index <- readRDS("R/experiments/bikerental/train_index.rds")

# Feature matrices: rows in train_index for training, all others for testing.
x_train <- as.matrix(bike[train_index, x_var])
x_test <- as.matrix(bike[-train_index, x_var])

# Raw (not centered) targets as one-column matrices.
y_train_nc <- as.matrix(bike[train_index, y_var])
y_test_nc <- as.matrix(bike[-train_index, y_var])

# Both splits are shifted by the *training* mean only.
y_train <- y_train_nc - mean(y_train_nc)
y_test <- y_test_nc - mean(y_train_nc)

# Fit an XGBoost model (default parameters, 100 boosting rounds) to the
# centered training targets. The argument is spelled `nrounds` in full; the
# previous `nround` only worked through R's partial argument matching.
model <- xgboost(
  data = x_train,
  label = y_train,
  nrounds = 100,
  verbose = FALSE
)
# caret::RMSE(y_test, predict(model, x_test))

# message("1. Prepared and plotted data, trained XGBoost model")
#
# # 2 - Compute Shapley Values ----------------------------------------------
#

## Load code from Heskes et al. for causal-chain and (their version of) marginal Shapley values
## available at https://proceedings.neurips.cc/paper/2020/hash/32e54441e6382a7fbacbbbaf3c450059-Abstract.html
# NOTE(review): the fork is attached as "package:shapr" (it is detached by that
# name further below before our own shapr is loaded).
devtools::load_all("shapr-heskes")

# Explainer built from the training features and the fitted XGBoost model.
explainer_symmetric <- shapr(x_train, model)
# Reference prediction passed as prediction_zero. y_train was centered on its
# own mean, so this is numerically ~0.
p <- mean(y_train)

# Both Heskes et al. explanations use the "causal" approach with identical
# settings; they differ only in the causal ordering and confounding flags.
explain_heskes_causal <- function(ordering, confounding) {
  explain(
    x_test,
    approach = "causal",
    explainer = explainer_symmetric,
    prediction_zero = p,
    ordering = ordering,
    confounding = confounding,
    seed = 2020
  )
}

# a. Causal Shapley values on the partial order (see paper)
#    {trend} -> {cosyear, sinyear} -> {temp, atemp, windspeed, hum},
#    with confounding only within the {cosyear, sinyear} component.
partial_order <- list(1, c(2, 3), c(4:7))
explanation_causal <- explain_heskes_causal(
  ordering = partial_order,
  confounding = c(FALSE, TRUE, FALSE)
)
saveRDS(explanation_causal, "R/experiments/bikerental/explanation_causal_heskes.rds")

# b. Marginal Shapley values: a single component holding every feature,
#    with confounding.
explanation_marginal <- explain_heskes_causal(
  ordering = list(c(1:7)),
  confounding = TRUE
)
saveRDS(explanation_marginal, "R/experiments/bikerental/explanation_marginal_heskes.rds")


# sina_marginal <- sina_plot(explanation_marginal) +
#   coord_flip(ylim = ylim_causal) + ylab("Marginal Shapley value (impact on model output)")
#
# if (save_plots) {
#   ggsave("figures/sina_plot_marginal.pdf", sina_marginal, height = 6.5, width = 6.5)
# } else {
#   print(sina_marginal)
# }

# message("2b. Computed and plotted marginal Shapley values")

## Load our shapr package
# The Heskes et al. fork loaded above is attached as "package:shapr"; unload it
# before loading our own version so the two do not clash.
detach("package:shapr", unload = TRUE)
devtools::load_all("shapr")

# Settings shared by every explanation computed with our shapr below.
x_to_explain <- x_test
prediction_zero <- p
seed <- 2020 # Same seed as Heskes
n_batches <- 8
n_samples_expectation <- 1000
n_combinations_shapr <- NULL

# Every explanation below shares the model, data and sampling settings; only
# the approach (and, for "mec", the DAGs passed through `...`) differs.
explain_with_shapr <- function(approach, ...) {
  shapr::explain(
    model = model,
    x_train = x_train,
    x_explain = x_to_explain,
    prediction_zero = prediction_zero,
    n_samples = n_samples_expectation,
    n_combinations = n_combinations_shapr,
    seed = seed,
    n_batches = n_batches,
    approach = approach,
    ...
  )
}

# Conditional Shapley values (Gaussian approach).
explanation_conditional <- explain_with_shapr("gaussian")
saveRDS(explanation_conditional, "R/experiments/bikerental/explanation_conditional.rds")

# "mec" approach with the DAGs found by FGES (read from disk).
mec_fges <- readRDS("R/experiments/bikerental/mec_fges.rds")
explanation_mec_fges <- explain_with_shapr("mec", dags = mec_fges)
saveRDS(explanation_mec_fges, "R/experiments/bikerental/explanation_mec_fges.rds")

# "mec" approach with the DAGs found by PC (read from disk).
mec_pc <- readRDS("R/experiments/bikerental/mec_pc.rds")
explanation_mec_pc <- explain_with_shapr("mec", dags = mec_pc)
saveRDS(explanation_mec_pc, "R/experiments/bikerental/explanation_mec_pc.rds")

# Reload every explanation from disk, so the plotting code below only depends
# on the saved .rds files.
results_dir <- "R/experiments/bikerental"
explanation_marginal <-
  readRDS(file.path(results_dir, "explanation_marginal_heskes.rds"))
explanation_causal <-
  readRDS(file.path(results_dir, "explanation_causal_heskes.rds"))
explanation_conditional <-
  readRDS(file.path(results_dir, "explanation_conditional.rds"))
explanation_mec_pc <-
  readRDS(file.path(results_dir, "explanation_mec_pc.rds"))
explanation_mec_fges <-
  readRDS(file.path(results_dir, "explanation_mec_fges.rds"))

# Bar plots
# Position (within x_test) of the test row for a given calendar date. The row
# names of x_test are the original row numbers of `bike`. The helper stops with
# a clear message when the date is not exactly one test row, instead of letting
# the slices below fail later with an unrelated size error.
test_row_for_date <- function(date) {
  bike_row <- which(bike$dteday == date)
  test_row <- which(as.integer(row.names(x_test)) == bike_row)
  if (length(test_row) != 1L) {
    stop("Date ", date, " does not correspond to exactly one row of x_test.",
         call. = FALSE)
  }
  test_row
}

october <- test_row_for_date("2012-10-09")
december <- test_row_for_date("2012-12-03")


# Keep the two highlighted days (rows `october` and `december` of the test
# set) and tag them with their date and the explanation type for plotting.
tag_highlighted_days <- function(shap, label) {
  shap %>%
    dplyr::slice(c(october, december)) %>%
    mutate(date = c("2012-10-09", "2012-12-03"), type = label)
}

dt_marginal <- explanation_marginal$dt %>%
  tag_highlighted_days("Marginal") %>%
  select(cosyear, temp, date, type)

dt_conditional <- explanation_conditional$shapley_values %>%
  tag_highlighted_days("Conditional") %>%
  select(cosyear, temp, date, type)

dt_causal <- explanation_causal$dt %>%
  tag_highlighted_days("Causal (chain)") %>%
  select(cosyear, temp, date, type)

# The MEC results come as a list of unnamed tables; V3/V5 are renamed to
# cosyear/temp here.
dt_fges <- explanation_mec_fges$shapley_values[[2]] %>%
  tag_highlighted_days("FGES") %>%
  rename(cosyear = V3, temp = V5) %>%
  select(cosyear, temp, date, type)

dt_pc <- explanation_mec_pc$shapley_values[[1]] %>%
  tag_highlighted_days("PC") %>%
  rename(cosyear = V3, temp = V5) %>%
  select(cosyear, temp, date, type)

# Reshape a per-type table (columns cosyear, temp, date, type) into long form:
# one row per (day, feature) with the feature in `name` and its value in `value`.
to_long_format <- function(dt) {
  pivot_longer(dt, c(cosyear, temp))
}

# Bar chart of Shapley values for the highlighted days: one facet per
# explanation `type` (all in a single row) with dodged bars per feature/date.
plot_highlighted_days <- function(dt_long) {
  n_types <- length(unique(dt_long$type))
  ggplot(dt_long, aes(x = name, y = value, group = interaction(date, name),
                      fill = date, label = round(value, 2))) +
    geom_col(position = "dodge") +
    theme_classic() +
    ylab("Shapley value") +
    facet_wrap(vars(type), ncol = n_types) +
    theme(axis.title.x = element_blank()) +
    scale_fill_manual(values = c("indianred4", "ivory4")) +
    theme(
      legend.position = "top",
      legend.direction = "horizontal",
      axis.title = element_text(size = 20),
      legend.title = element_text(size = 16),
      legend.text = element_text(size = 14),
      axis.text.x = element_text(size = 12),
      axis.text.y = element_text(size = 12),
      strip.text.x = element_text(size = 14)
    )
}

# Marginal vs. causal (chain) vs. PC.
dt_all <- dplyr::bind_rows(
  lapply(list(dt_marginal, dt_causal, dt_pc), to_long_format)
)
bar_plots <- plot_highlighted_days(dt_all)
ggsave("R/experiments/bikerental/bar_plot_marg_causal_pc.pdf", bar_plots,
       width = 6, height = 3)

# Same comparison with the conditional Shapley values added.
dt_all <- dplyr::bind_rows(
  lapply(list(dt_marginal, dt_conditional, dt_causal, dt_pc), to_long_format)
)
bar_plots <- plot_highlighted_days(dt_all)
ggsave("R/experiments/bikerental/bar_plot_marg_cond_causal_pc.pdf", bar_plots,
       width = 8, height = 5)


# Bar plot sinyear
# Same two days as above, but showing sinyear instead of cosyear.

dt_marginal <- explanation_marginal$dt %>%
  dplyr::slice(c(october, december)) %>%
  select(sinyear, temp) %>%
  mutate(date = c("2012-10-09", "2012-12-03"), type = "Marginal")

dt_causal <- explanation_causal$dt %>%
  dplyr::slice(c(october, december)) %>%
  select(sinyear, temp) %>%
  mutate(date = c("2012-10-09", "2012-12-03"), type = "Causal")

# `explanation_mec_fges_sembic_trivial_bg` (presumably FGES with trivial
# background knowledge) is not created earlier in this script. Referencing it
# directly aborted the run with "object not found", so the plot is only built
# when the object has been defined beforehand. `save_plots` is optional too and
# defaults to printing the plot.
fges_trivial_bg <- get0("explanation_mec_fges_sembic_trivial_bg",
                        ifnotfound = NULL)

if (is.null(fges_trivial_bg)) {
  message("Skipping sinyear bar plot: ",
          "`explanation_mec_fges_sembic_trivial_bg` is not defined.")
} else {
  # Unnamed MEC columns: V4/V5 are renamed to sinyear/temp here.
  dt_fges_trivial_bg <- fges_trivial_bg$shapley_values[[2]] %>%
    dplyr::slice(c(october, december)) %>%
    mutate(date = c("2012-10-09", "2012-12-03"), type = "FGES") %>%
    rename(sinyear = V4, temp = V5) %>%
    select(sinyear, temp, date, type)

  dt_all <- dt_marginal %>% pivot_longer(c(sinyear, temp)) %>%
    rbind(dt_causal %>% pivot_longer(c(sinyear, temp))) %>%
    rbind(dt_fges_trivial_bg %>% pivot_longer(c(sinyear, temp)))

  bar_plots <- ggplot(dt_all, aes(x = name, y = value,
                                  group = interaction(date, name),
                                  fill = date, label = round(value, 2))) +
    geom_col(position = "dodge") +
    theme_classic() + ylab("Shapley value") +
    facet_wrap(vars(type)) + theme(axis.title.x = element_blank()) +
    scale_fill_manual(values = c("indianred4", "ivory4")) +
    theme(legend.position = c(0.75, 0.35), axis.title = element_text(size = 20),
          legend.title = element_text(size = 16),
          legend.text = element_text(size = 14),
          axis.text.x = element_text(size = 12),
          axis.text.y = element_text(size = 12),
          strip.text.x = element_text(size = 14))

  if (isTRUE(get0("save_plots", mode = "logical", ifnotfound = FALSE))) {
    ggsave("R/experiments/bikerental/bar_plot_fges_trivial_bg_sinyear.pdf",
           bar_plots, width = 6, height = 3)
  } else {
    print(bar_plots)
  }
}


## Sina plots
source("R/experiments/bikerental/sina_plot.R")
height <- 7
width <- 7

# Write a sina plot into the bike-rental results folder at the shared size.
save_sina_plot <- function(fig, file_name) {
  ggplot2::ggsave(file.path("R/experiments/bikerental", file_name), fig,
                  height = height, width = width, dpi = 300)
}

sina_marginal <- sina_plot(explanation_marginal, x_bound = 2000) +
  ylab("Marginal Shapley value")
save_sina_plot(sina_marginal, "sina_plot_marginal.pdf")

# Our shapr output keeps its values in `shapley_values`. Copy them (and the
# explained features) into the `dt` / `x_test` fields that the Heskes-style
# objects carry (presumably what sina_plot() reads).
explanation_conditional$x_test <- x_test
explanation_conditional$dt <- explanation_conditional$shapley_values
sina_conditional <- sina_plot(explanation_conditional, x_bound = 2000) +
  ylab("Conditional Shapley value")
save_sina_plot(sina_conditional, "sina_plot_conditional.pdf")

sina_causal <- sina_plot(explanation_causal) +
  ylab("Causal (chain) Shapley value")
save_sina_plot(sina_causal, "sina_plot_causal.pdf")

# The FGES `shapley_values` is a list; judging by the labels, elements 1 and 2
# correspond to the orientations windspeed -> hum and hum -> windspeed. Their
# columns are unnamed, so they borrow the names of the causal explanation.
explanation_mec_fges$x_test <- x_test
explanation_mec_fges$dt <- explanation_mec_fges$shapley_values[[1]]
colnames(explanation_mec_fges$dt) <- colnames(explanation_causal$dt)
sina_fges_one <- sina_plot(explanation_mec_fges) +
  ylab("FGES Shapley value (windspeed -> hum)")
save_sina_plot(sina_fges_one, "sina_plot_fges_one.pdf")

explanation_mec_fges$dt <- explanation_mec_fges$shapley_values[[2]]
colnames(explanation_mec_fges$dt) <- colnames(explanation_causal$dt)
sina_fges_two <- sina_plot(explanation_mec_fges) +
  ylab("FGES Shapley value (hum -> windspeed)")
save_sina_plot(sina_fges_two, "sina_plot_fges_two.pdf")

explanation_mec_pc$x_test <- x_test
explanation_mec_pc$dt <- explanation_mec_pc$shapley_values[[1]]
# Re-add column names
colnames(explanation_mec_pc$dt) <- colnames(explanation_causal$dt)
sina_pc <- sina_plot(explanation_mec_pc, x_bound = 2000) +
  ylab("PC Shapley value")
save_sina_plot(sina_pc, "sina_plot_pc.pdf")
