library(tidyverse)
library(lubridate)

# Adult Income ----

# Load the Adult Income model inputs together with their per-feature
# KernelSHAP values and scaled partial residuals, then join all three on the
# leading column (renamed to `row`; presumably a row index written by the
# exporting script -- confirm).
dat <- read_csv('data/adult_input.csv')
shapley_values <- read_csv('data/adult_shapley_values.csv')
partial_residuals <- read_csv('data/adult_partial_residuals.csv')

# The SHAP / residual columns are renamed positionally from the input's
# feature columns, so every file must have the same number of columns;
# otherwise the rename would recycle names onto the wrong features.
stopifnot(
  ncol(shapley_values) == ncol(dat),
  ncol(partial_residuals) == ncol(dat)
)

colnames(dat)[1] <- 'row'
colnames(shapley_values)[1] <- 'row'
colnames(partial_residuals)[1] <- 'row'

colnames(shapley_values)[-1] <- str_c(colnames(dat)[-1], '_shap')
colnames(partial_residuals)[-1] <- str_c(colnames(dat)[-1], '_residual')

# Join explicitly on `row`: after suffixing it is the only shared column, and
# naming it prevents an accidental natural join on any other coincident name.
dat <- dat %>%
  left_join(shapley_values, by = 'row') %>%
  left_join(partial_residuals, by = 'row')


# Scatter of the sex KernelSHAP value against its scaled residual, one facet
# per sex and coloured by marital status; written to the sex/marital
# interaction figure.
ggplot(dat, aes(x = sex_shap, y = sex_residual, color = marital_status)) +
  geom_point() +
  facet_wrap(~sex) +
  theme_linedraw() +
  labs(
    x = 'Sex KernelSHAP value',
    y = 'Sex KernelSHAP scaled residual',
    color = 'Marital status'
  )
ggsave('figs/sex_marital_interaction.png', width = 8, height = 4)




# NHANES -----

# Load the NHANES model inputs with their KernelSHAP values and scaled partial
# residuals, give the feature columns readable names, and join on `row`.
dat <- read_csv('data/nhanes_margin_input.csv')
shapley_values <- read_csv('data/nhanes_margin_shapley_values.csv')
partial_residuals <- read_csv('data/nhanes_margin_partial_residuals.csv')

# Feature names in CSV column order, after the leading index column. The '_'
# entry looks like a placeholder for a column not referenced in this section
# -- NOTE(review): confirm what it holds.
new_colnames <- c('_', 'age', 'diastolic_bp', 'sex', 'systolic_bp', 'poverty',
                  'white_blood_cells', 'bmi')

# Names are assigned positionally, so each file must contain exactly one
# leading index column plus one column per name; a mismatch would otherwise
# recycle names onto the wrong columns with only a warning.
stopifnot(
  ncol(dat) == length(new_colnames) + 1,
  ncol(shapley_values) == ncol(dat),
  ncol(partial_residuals) == ncol(dat)
)

colnames(dat) <- c('row', new_colnames)
colnames(shapley_values) <- c('row', str_c(new_colnames, '_shap'))
colnames(partial_residuals) <- c('row', str_c(new_colnames, '_residual'))

# `row` is the only column shared after suffixing; name it explicitly.
dat <- dat %>%
  left_join(shapley_values, by = 'row') %>%
  left_join(partial_residuals, by = 'row')

# Systolic blood pressure against its KernelSHAP value, then against its
# scaled residual, both coloured by age on a blue-to-pink gradient. Readings
# above 220 are dropped from both plots.
ggplot(filter(dat, systolic_bp <= 220),
       aes(x = systolic_bp, y = systolic_bp_shap, color = age)) +
  geom_point() +
  scale_color_gradient(low = 'blue', high = 'deeppink') +
  theme_linedraw() +
  labs(
    x = 'Systolic blood pressure',
    y = 'Systolic blood pressure KernelSHAP value',
    color = 'Age'
  )
ggsave('figs/bp_shap.png', width = 6, height = 4)

ggplot(filter(dat, systolic_bp <= 220),
       aes(x = systolic_bp, y = systolic_bp_residual, color = age)) +
  geom_point() +
  scale_color_gradient(low = 'blue', high = 'deeppink') +
  theme_linedraw() +
  labs(
    x = 'Systolic blood pressure',
    y = 'Systolic blood pressure KernelSHAP scaled residual',
    color = 'Age'
  )
ggsave('figs/bp_resid.png', width = 6, height = 4)


# Sex KernelSHAP value (first plot) and its scaled residual (second plot)
# against age. The numeric sex code is relabelled for the legend: 2 -> 'F',
# any other non-missing code -> 'M'. Points are jittered.
ggplot(mutate(dat, sex = if_else(sex == 2, 'F', 'M')),
       aes(x = age, y = sex_shap, color = sex)) +
  geom_jitter() +
  labs(
    x = 'Age',
    y = 'Sex KernelSHAP value',
    color = 'Sex'
  )
ggsave('figs/age_sex_interaction.png', width = 6, height = 4)


ggplot(mutate(dat, sex = if_else(sex == 2, 'F', 'M')),
       aes(x = age, y = sex_residual, color = sex)) +
  geom_jitter() +
  labs(
    x = 'Age',
    y = 'Sex KernelSHAP scaled residual',
    color = 'Sex'
  )
ggsave('figs/age_sex_interaction_residual.png', width = 6, height = 4)



# Occupancy -----


# Full occupancy data set; randomly sampled further below for the
# SHAP-calculation illustration.
occupancy <- read_csv('data/occupancy.csv')

# Small occupancy model inputs with their KernelSHAP values and scaled
# partial residuals.
dat <- read_csv('data/occupancy_small_input.csv')
shapley_values <- read_csv('data/occupancy_small_shapley_values.csv')
partial_residuals <- read_csv('data/occupancy_small_partial_residuals.csv')

# Width / height (ggsave's default unit, inches) shared by the occupancy plots.
w <- 5
h <- 4

# SHAP / residual columns are renamed positionally from the input's feature
# columns, so the column counts must agree.
stopifnot(
  ncol(shapley_values) == ncol(dat),
  ncol(partial_residuals) == ncol(dat)
)

colnames(dat)[1] <- 'row'
colnames(shapley_values)[1] <- 'row'
colnames(partial_residuals)[1] <- 'row'

colnames(shapley_values)[-1] <- str_c(colnames(dat)[-1], '_shap')
colnames(partial_residuals)[-1] <- str_c(colnames(dat)[-1], '_residual')

# `row` is the only column shared after suffixing; name it explicitly.
dat <- dat %>%
  left_join(shapley_values, by = 'row') %>%
  left_join(partial_residuals, by = 'row')

# Axis-aligned rectangles over the (hour, light) plane, one row per rectangle:
# x bounds are in hour units, y bounds in light units, and `pred` is the fill
# value drawn for the rectangle. The fractions look like per-region class
# proportions (e.g. 1203/1246) -- NOTE(review): confirm their source model.
regions <- data.frame(
  xmin = c(0,     0,       0,       12.5,    0,       7.5,     0,      12.5),
  xmax = c(24,    24,      12.5,    24,      7.5,     24,      12.5,   24),
  ymin = c(0,     212.5,   289.25,  289.25,  369.875, 369.875, 563.75, 563.75),
  ymax = c(212.5, 289.25,  369.875, 369.875, 563.75,  563.75,  699,    699),
  pred = c(0,     1/80,    2/128,   1,       25/32,   1203/1246, 15/21, 5/13)
)

# Light-residual figure: the prediction regions drawn as filled rectangles
# (diverging fill centred at 0.5), with every observation overlaid at its
# (hour, light) position and shaded light-to-dark by its scaled light residual.
regions %>%
  ggplot() +
  geom_rect(aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = pred),
            alpha = 1) +
  # Plot exact positions (no jitter). This matches the companion light-SHAP
  # figure, which disables jitter, so the two figures place points
  # identically. The default jitter is random and unseeded, which also made
  # this PNG differ from run to run.
  geom_point(aes(x = hour, y = light, color = light_residual), data = dat) +
  ylim(c(0, 700)) +
  scale_fill_gradient2(
    low = hcl(30, 30, 80),
    mid = hcl(30, 0, 100),
    high = hcl(210, 30, 80),
    midpoint = 0.5,
    space = "Lab",
    na.value = "grey50",
    aesthetics = "fill"
  ) +
  scale_color_gradient(
    low = hcl(300, 0, 90),
    high = hcl(300, 0, 0),
    space = "Lab",
    na.value = "grey50",
    guide = "colourbar",
    aesthetics = "color"
  ) +
  theme_linedraw() +
  labs(x = 'Hour', y = 'Light', color = 'Light Residual', fill = 'Prediction')
ggsave('figs/occupancy_light_resid.png', width = w, height = h)


# Base hue (degrees, as passed to hcl()) for the occupancy palettes; the other
# hues are derived as fixed offsets from it.
base <- 30

# Light-SHAP figure: prediction regions as filled rectangles (diverging fill
# centred at 0.5), with each observation at its exact (hour, light) position,
# coloured by its light KernelSHAP value on a diverging scale centred at 0.
regions %>%
  ggplot() +
  geom_rect(aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = pred),
            alpha = 1) +
  # geom_point() is the idiomatic form of geom_jitter(width = 0, height = 0).
  geom_point(aes(x = hour, y = light, color = light_shap), data = dat) +
  ylim(c(0, 700)) +
  scale_fill_gradient2(
    low = hcl(base, 30, 80),
    mid = hcl(base, 0, 100),
    high = hcl(base + 180, 30, 80),
    space = "Lab",
    midpoint = 0.5,
    na.value = "grey50",
    aesthetics = "fill"
  ) +
  scale_color_gradient2(
    low = hcl(base + 270, 50, 20),
    mid = hcl(base + 270, 0, 100),
    high = hcl(base + 90, 50, 20),
    space = "Lab",
    midpoint = 0,
    na.value = "grey50",
    guide = "colourbar",
    aesthetics = "color"
  ) +
  theme_linedraw() +
  labs(x = 'Hour', y = 'Light', color = 'Light SHAP', fill = 'Prediction')

ggsave('figs/occupancy_light_shap.png', width = w, height = h)

# SHAP-calculation illustration: draw 100 random observations and make two
# copies of each, one with light fixed at 320 and one with hour fixed at 10.
# Over the prediction regions this gives a horizontal band (light fixed) and a
# vertical band (hour fixed) of crosses. NOTE(review): (hour 10, light 320) is
# presumably the instance being explained -- confirm.
#
# Seed the RNG so both the sample and the later jitter are reproducible;
# without it the figure changed on every run.
set.seed(1)
dat_samp <- occupancy %>% slice_sample(n = 100)
dat_samp_light <- dat_samp %>% mutate(light = 320, fixed = 'light')
dat_samp_hour <- dat_samp %>% mutate(hour = 10, fixed = 'hour')
dat_samp <- dat_samp_light %>% bind_rows(dat_samp_hour)

ggplot(dat_samp) +
  geom_rect(aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = pred),
            data = regions, alpha = 1) +
  geom_jitter(aes(x = hour, y = light, color = fixed), shape = 'cross') +
  # Colours are keyed by level name, so the mapping does not depend on level
  # order (it matches the original alphabetical 'hour', 'light' assignment).
  scale_colour_manual(values = c(hour = hcl(base + 120, 60, 40),
                                 light = hcl(base + 300, 60, 40))) +
  scale_fill_gradient2(
    low = hcl(base, 30, 80),
    mid = hcl(base, 0, 100),
    high = hcl(base + 180, 30, 80),
    space = "Lab",
    midpoint = 0.5,
    na.value = "grey50",
    aesthetics = "fill"
  ) +
  theme_linedraw() +
  labs(x = 'Hour', y = 'Light', color = 'Fixed variable', fill = 'Prediction')
ggsave('figs/occupancy_shap_calc.png', width = w, height = h)
