library(dplyr)
library(forcats)
library(ggplot2)
library(tidyr)
################################################################################
########################### Run this for setup =5,6,7 ##########################
################################################################################
# setup = 5
# rst_setup_JNA <- read.table(file = sprintf('../Results_LR_Summary/setup%d_JNA.txt', setup), header=TRUE)
# rst_setup_J5 <- read.table(file = sprintf('../Results_LR_Summary/setup%d_J5.txt', setup), header=TRUE)
# rst_setup_J10 <- read.table(file = sprintf('../Results_LR_Summary/setup%d_J10.txt', setup), header=TRUE)
# rst_setup_J20 <- read.table(file = sprintf('../Results_LR_Summary/setup%d_J20.txt', setup), header=TRUE)
# 
# rst_setup_JNA <- rst_setup_JNA %>%
#   select(-src) %>%
#   mutate(setup = setup)
# 
# rst_setup_J5 <- rst_setup_J5 %>%
#   select(-src) %>%
#   mutate(setup = setup)
# 
# rst_setup_J10 <- rst_setup_J10 %>%
#   select(-src) %>%
#   mutate(setup = setup)
# 
# rst_setup_J20 <- rst_setup_J20 %>%
#   select(-src) %>%
#   mutate(setup = setup)
# ################################################################################
# 
# rst_setup <- rbind(rst_setup_JNA, rst_setup_J5, rst_setup_J10, rst_setup_J20)
# write.table(rst_setup,
#             file = sprintf('../Results_LR_Summary/setup%d.txt', setup),
#             row.names = FALSE, col.names = TRUE,
#             quote = FALSE, sep = "\t")
# 
# dat_setup5 <- read.table('../Results_LR_Summary/setup5.txt', header = TRUE)
# dat_setup6 <- read.table('../Results_LR_Summary/setup6.txt', header = TRUE)
# dat_setup7 <- read.table('../Results_LR_Summary/setup7.txt', header = TRUE)
# write.table(rbind(dat_setup5, dat_setup6, dat_setup7),
#             file = sprintf('Results_LR_Summary/LR_summary.txt'),
#             row.names = FALSE, col.names = TRUE,
#             quote = FALSE, sep = "\t")
# rst_true_setup5 <- read.table('Results_LR_Summary/n_30000_setup5_true.txt', header = TRUE) %>%
#   mutate(setup = 5)
# rst_true_setup6 <- read.table('Results_LR_Summary/n_30000_setup6_true.txt', header = TRUE) %>%
#   mutate(setup = 6)
# rst_true_setup7 <- read.table('Results_LR_Summary/n_30000_setup7_true.txt', header = TRUE) %>%
#   mutate(setup = 7)
# write.table(rbind(rst_true_setup5, rst_true_setup6, rst_true_setup7),
#             file = sprintf('Results_LR_Summary/LR_True_ACE.txt'),
#             row.names = FALSE, col.names = TRUE,
#             quote = FALSE, sep = "\t")

################################################################################
######################## Computational Efficiency Table ########################
################################################################################
# Load the combined simulation summary and recode J as text, so that a missing
# J is carried as the literal label 'NA' (later displayed as the 'Complete'
# analysis) instead of a true NA.
rst <- read.table('../Results_LR_Summary/LR_summary.txt', header = TRUE) %>%
  mutate(J = if_else(is.na(J), 'NA', as.character(J)))

# Reference values keyed by Time / treatment / setup.
# NOTE(review): presumably the source of the TrueRisk column used in the
# coverage calculations further down -- confirm against the file.
rst_true <- read.table('../Results_LR_Summary/LR_True_ACE.txt', header = TRUE)

# Attach the reference values to every matching row of the summary.
rst <- left_join(rst, rst_true, by = c('Time', 'treatment', 'setup'))

# Computation time per configuration, formatted as 'mean (sd)' strings.
# One output row per (setup, J, treatment, iterative) combination:
#   All_Computation     - Computation_Fit + Computation_Data
#   Fitting_Computation - Computation_Fit only
# A missing Computation_Data is counted as 0 so that it does not turn the
# total into NA.
summary_time_each_simu <- rst %>%
  select(setup, J, treatment, iterative, Computation_Data, Computation_Fit) %>%
  mutate(
    Computation_Data = replace(Computation_Data, is.na(Computation_Data), 0),
    Computation_Total = Computation_Fit + Computation_Data
  ) %>%
  group_by(setup, J, treatment, iterative) %>%
  summarise(
    All_Computation = sprintf(
      '%.2f (%.2f)',
      mean(Computation_Total, na.rm = TRUE),
      sd(Computation_Total, na.rm = TRUE)
    ),
    Fitting_Computation = sprintf(
      '%.2f (%.2f)',
      mean(Computation_Fit, na.rm = TRUE),
      sd(Computation_Fit, na.rm = TRUE)
    ),
    # Return an ungrouped table directly instead of relying on the implicit
    # grouping default followed by ungroup().
    .groups = 'drop'
  )

# Computation-time table for setup 5: one row per matching scheme, one column
# per estimator/treatment combination ('ICE'/'NICE' x 'A=1'/'A=0'), cells
# holding the 'mean (sd)' string of the total computation time.
summary_time_each_simu %>%
  filter(setup == 5) %>%
  mutate(
    # J is stored as text (missing J is the literal 'NA'), so compare against
    # strings rather than relying on implicit number-to-string coercion.
    matchedway = case_when(
      J == 'NA' ~ 'Complete',
      J == '20' ~ 'Case-control (J=20)',
      J == '10' ~ 'Case-control (J=10)',
      J == '5'  ~ 'Case-control (J=5)'
    ),
    # Factor levels fix the row order of the printed table.
    matchedway = factor(matchedway,
                        levels = c('Complete',
                                   'Case-control (J=20)',
                                   'Case-control (J=10)',
                                   'Case-control (J=5)')),
    # iterative -> 'ICE', non-iterative -> 'NICE'; treatment -> 'A=1' / 'A=0'.
    col_label = sprintf('%sICE (%s)',
                        ifelse(iterative, '', 'N'),
                        ifelse(treatment, 'A=1', 'A=0'))
  ) %>%
  select(matchedway, col_label, All_Computation) %>%
  pivot_wider(
    names_from = col_label,
    # Swap in Fitting_Computation (here and in select()) to tabulate the
    # fitting time only.
    values_from = All_Computation
  ) %>%
  arrange(matchedway) %>%
  # 'booktabs' spelled out in full; the previous 'booktab' only worked through
  # partial argument matching.
  knitr::kable(booktabs = TRUE) %>%
  # knitr::kable(format = 'latex', booktabs = TRUE) %>%
  kableExtra::kable_styling()


################################################################################
########################## Estimation Efficiency Table #########################
################################################################################
# Per-dataset summaries of the Risk values within each
# (setup, J, treatment, iterative, Time, dataID) cell.
# NOTE(review): the Boots_* names suggest the rows within a cell are bootstrap
# draws -- confirm against the simulation code.
#
# Three 90% intervals are built and checked against the reference value
# TrueRisk (first value in the cell):
#   Quantile_* - 5% / 95% empirical quantiles
#   HPD_*      - highest-density interval from coda (prob = 0.9)
#   Aysm_*     - normal approximation, mean +/- qnorm(0.95) * sd
# The *_cover columns are plain logicals (TRUE when the interval strictly
# contains TrueRisk). Column names, including the 'Aysm' spelling, are kept
# because later code refers to them.
summary_each_simu <- rst %>%
  group_by(setup, J, treatment, iterative, Time, dataID) %>%
  summarise(
    Boots_mean = mean(Risk),
    Boots_sd = sd(Risk),
    # names = FALSE: keep the '5%' / '95%' labels out of the summary columns.
    Quantile_lower = quantile(Risk, 0.05, names = FALSE),
    Quantile_upper = quantile(Risk, 0.95, names = FALSE),
    Quantile_cover = Quantile_upper > TrueRisk[1] & Quantile_lower < TrueRisk[1],
    HPD_lower = coda::HPDinterval(coda::mcmc(Risk), prob = 0.9)[1],
    HPD_upper = coda::HPDinterval(coda::mcmc(Risk), prob = 0.9)[2],
    HPD_cover = HPD_upper > TrueRisk[1] & HPD_lower < TrueRisk[1],
    Aysm_lower = Boots_mean - qnorm(0.95) * Boots_sd,
    Aysm_upper = Boots_mean + qnorm(0.95) * Boots_sd,
    Aysm_cover = Aysm_upper > TrueRisk[1] & Aysm_lower < TrueRisk[1],
    # Return an ungrouped table directly.
    .groups = 'drop'
  )

# Aggregate the per-dataset summaries over dataID, giving one row per
# (setup, J, iterative, treatment, Time):
#   mean_Boots_mean - average of the per-dataset means
#   mean_Boots_sd   - average of the per-dataset standard deviations
#   sd_Boots_mean   - standard deviation of the per-dataset means
#   *_cover_rate    - share of datasets whose interval covered the reference
summary_all <- summary_each_simu %>%
  group_by(setup, J, iterative, treatment, Time) %>%
  summarise(
    mean_Boots_mean = mean(Boots_mean),
    mean_Boots_sd = mean(Boots_sd),
    sd_Boots_mean = sd(Boots_mean),
    Quantile_cover_rate = mean(Quantile_cover),
    HPD_cover_rate = mean(HPD_cover),
    Aysm_cover_rate = mean(Aysm_cover),
    # Return an ungrouped table directly instead of relying on the implicit
    # grouping default followed by ungroup().
    .groups = 'drop'
  )

# Treatment arm shown in the table below (1 gives columns labelled 'A=1').
bar_a <- 1

# Estimation table for setup 5 and treatment arm bar_a: one row per Time, one
# column per estimator / matching scheme, cells holding
# '100 * mean_Boots_mean (100 * mean_Boots_sd)'.
summary_all %>%
  filter(setup == 5 & treatment == bar_a) %>%
  arrange(fct_relevel(J, "5", "10", "20", "NA")) %>%
  mutate(
    # iterative -> 'ICE', non-iterative -> 'NICE'; J == 'NA' -> 'Complete'.
    col_label = sprintf('%sICE (%s, %s)',
                        ifelse(iterative, '', 'N'),
                        ifelse(J != 'NA', sprintf('Match 1:%s', J), 'Complete'),
                        ifelse(treatment, 'A=1', 'A=0')),
    # summary_all already has a single row per (Time, col_label) after the
    # filter above, so the values are formatted directly without re-averaging.
    mean_sd = sprintf('%.2f (%.2f)', 100 * mean_Boots_mean, 100 * mean_Boots_sd)
  ) %>%
  select(Time, col_label, mean_sd) %>%
  mutate(
    # Column order of the final table: ICE then NICE, each running over
    # Match 1:5, 1:10, 1:20, Complete.
    col_label = factor(
      col_label,
      levels = sprintf('%s (%s, A=%d)',
                       rep(c('ICE', 'NICE'), each = 4),
                       rep(c('Match 1:5', 'Match 1:10', 'Match 1:20', 'Complete'),
                           times = 2),
                       bar_a)
    )
  ) %>%
  pivot_wider(
    names_from = col_label,
    values_from = mean_sd,
    names_sort = TRUE
  ) %>%
  # 'booktabs' spelled out in full; the previous 'booktab' only worked through
  # partial argument matching.
  knitr::kable(booktabs = TRUE) %>%
  # knitr::kable(format = 'latex', booktabs = TRUE) %>%
  kableExtra::kable_styling()


################################################################################
############################ Risk Estimators Figure ############################
################################################################################
# Plot data for the risk-estimator figure: per-dataset summaries for setup 5,
# iterative estimator only, with display labels for the facets and the fill
# legend.
dfp <- summary_each_simu %>%
  filter(setup == 5 & iterative == 1) %>%
  mutate(
    # J arrives as text with the literal 'NA' for the complete analysis; turn
    # it back into an integer with a real NA. Any other non-numeric value also
    # becomes NA (the coercion warning is deliberately silenced).
    J = suppressWarnings(as.integer(ifelse(J %in% c("NA", ""), NA, J))),
    # Build the facet label from the raw treatment code *before* converting it
    # to a factor: comparing a factor with 1 matches on the level text, which
    # would mislabel every row if treatment were stored as TRUE/FALSE.
    treatment_lab = ifelse(treatment == 1, "bar(A)==bar(1)", "bar(A)==bar(0)"),
    treatment = factor(treatment),
    matched_lab = if_else(
      is.na(J),
      "(Complete)",
      paste0("(Case-control, J=", J, ")")
    ),
    # Levels fix the legend / dodge order; any J other than 5, 10 or 20 has no
    # level here and therefore becomes NA.
    matched_lab = factor(
      matched_lab,
      levels = c("(Complete)", "(Case-control, J=20)",
                 "(Case-control, J=10)", "(Case-control, J=5)")
    ),
    Time = factor(Time, levels = sort(unique(Time)))
  )

# Box plots of the per-dataset mean risk (Boots_mean) at each Time, dodged and
# filled by matching scheme, with one facet per treatment label. The facet
# labels are plotmath expressions, hence label_parsed.
ggplot(data = dfp, mapping = aes(x = Time, y = Boots_mean, fill = matched_lab)) +
  geom_boxplot(
    width = 0.8,
    linewidth = 0.3,
    outlier.size = 0.6,
    # preserve = "single" keeps box widths equal when a scheme is absent
    position = position_dodge2(width = 0.8, preserve = "single")
  ) +
  labs(x = "Time", y = "Risks", fill = "Matching scheme") +
  facet_wrap(vars(treatment_lab), labeller = label_parsed) +
  guides(fill = guide_legend(nrow = 2)) +
  theme_minimal(base_size = 14) +
  theme(legend.position = "bottom")
