# Load necessary libraries
library(ggplot2)
library(dplyr)
library(tidyr)
library(gridExtra)
library(grid)
library(ggthemes)
library(scales)

# Define the custom theme and scale functions from plot-bar-synthetic.r
# Publication-style ggplot2 theme.
#
# Starts from theme_foundation() and layers a fixed set of overrides on top:
# bold centred title, black axis lines, light-grey major grid, no minor grid,
# a compact horizontal legend at the bottom, and grey facet strips.
#
# Args:
#   base_size:   base font size forwarded to theme_foundation().
#   base_family: base font family forwarded to theme_foundation().
#
# Returns:
#   The sum of the foundation theme and the overrides, ready to be added to a
#   ggplot with `+`.
theme_Publication <- function(base_size=14, base_family="sans") {
  foundation <- theme_foundation(base_size = base_size,
                                 base_family = base_family)

  # Shared light grey used for both the major grid and the facet strips.
  light_grey <- "#f0f0f0"

  overrides <- theme(
    # Title and generic text
    plot.title = element_text(face = "bold", size = rel(1.2), hjust = 0.5),
    text = element_text(),

    # Backgrounds and border: no outline colour anywhere
    panel.background = element_rect(colour = NA),
    plot.background = element_rect(colour = NA),
    panel.border = element_rect(colour = NA),

    # Axes
    axis.title = element_text(face = "bold", size = rel(1)),
    axis.title.y = element_text(angle = 90, vjust = 2),
    axis.title.x = element_text(vjust = -0.2),
    axis.text = element_text(),
    axis.line = element_line(colour = "black"),
    axis.ticks = element_line(),

    # Grid: major lines only
    panel.grid.major = element_line(colour = light_grey),
    panel.grid.minor = element_blank(),

    # Legend: small keys, laid out horizontally under the plot
    legend.key = element_rect(colour = NA),
    legend.position = "bottom",
    legend.direction = "horizontal",
    legend.key.size = unit(0.2, "cm"),
    legend.spacing = unit(0, "cm"),
    legend.title = element_text(face = "italic"),

    # Outer margin (top, right, bottom, left) in millimetres
    plot.margin = unit(c(10, 5, 5, 5), "mm"),

    # Facet strips
    strip.background = element_rect(colour = light_grey, fill = light_grey),
    strip.text = element_text(face = "bold")
  )

  foundation + overrides
}

# Discrete fill scale built on the three-colour "Publication" palette.
#
# Args:
#   ...: forwarded unchanged to discrete_scale() after the aesthetic, the
#        scale name and the palette.
#
# Returns:
#   Whatever discrete_scale() returns for the "fill" aesthetic.
scale_fill_Publication <- function(...) {
  # The palette supplies exactly three colours.
  publication_colours <- c("#0f7ba2", "#7fc97f", "#dd5129")
  publication_palette <- manual_pal(values = publication_colours)

  discrete_scale("fill", "Publication", publication_palette, ...)
}

# Create data frame with subgroup sequences and their treatment probabilities.
#
# One row per (sequence group, step) pair: 3 groups x 4 steps = 12 rows.
# Every column must supply exactly 12 values; data.frame() silently recycles
# shorter vectors whose length divides 12, which would hide a truncated column.
treatment_data <- data.frame(
  sequence = c(
    # Sequence 1: Age-based (Increasing Risk)
    "Overall", "Age ≥ 65", "Age ≥ 65, BMI ≥ 30",
    "Age ≥ 65, BMI ≥ 30, Past Smoker",

    # Sequence 2: Physical Function (Increasing Risk)
    "Overall", "Low Function", "Low Function, Age ≥ 70",
    "Low Function, Age ≥ 70, Current Smoker",

    # Sequence 3: Very Elderly (Increasing Risk)
    "Overall", "Age ≥ 75", "Age ≥ 75, Low Function",
    "Age ≥ 75, Low Function, Severe Obesity"
  ),

  ### A signal
  probability = c(
    # Sequence 1 probabilities (Age-based Increasing Risk)
    0.3330, 0.2155, 0.1397, 0.1575,

    # Sequence 2 probabilities (Function-based Increasing Risk)
    0.3330, 0.2970, 0.1807, 0.2289,

    # Sequence 3 probabilities (Very Elderly Increasing Risk)
    0.3330, 0.1231, 0.1436, 0.0761
  ),

  ### Y signal
  # NOTE(review): the Y-signal vector was left unfinished (3 of the 12 values,
  # a trailing empty argument, and no comma after it) and reused the column
  # name `probability`, so the script could not be parsed. The values are kept
  # here as a comment. To plot the Y signal instead, fill in all 12 values and
  # swap this vector in for the A-signal `probability` column above.
  # probability = c(
  #   0.0173, 0.0287, 0.0427,
  # ),

  seq_group = c(
    # Group labels for each sequence
    rep("Age-based", 4),
    rep("Function-based", 4),
    rep("Very Elderly", 4)
  ),

  subgroup_num = c(
    # Position within each sequence
    1, 2, 3, 4,
    1, 2, 3, 4,
    1, 2, 3, 4
  ),

  # Keep the label columns as character; they are turned into factors with
  # explicit level order right after this block.
  stringsAsFactors = FALSE
)

# Convert the label columns to factors with an explicit level order so that
# plots follow the order written in the data rather than alphabetical order.

# Subgroup labels: levels follow first appearance in the data ("Overall"
# occurs once per sequence group, so unique() collapses the repeats).
treatment_data[["sequence"]] <- factor(
  treatment_data[["sequence"]],
  levels = unique(treatment_data[["sequence"]])
)

# Sequence groups: fixed facet order.
treatment_data[["seq_group"]] <- factor(
  treatment_data[["seq_group"]],
  levels = c("Age-based", "Function-based", "Very Elderly")
)

# # Create the bar plot
# p <- ggplot(treatment_data, aes(x = factor(subgroup_num), y = probability, 
#                                fill = sequence, group = interaction(seq_group, subgroup_num))) +
#   geom_bar(stat = "identity", position = position_dodge(width = 0.9), width = 0.8) +
#   facet_grid(. ~ seq_group, scales = "free_x", space = "free_x") +
#   labs(x = "",
#        y = "",
#        fill  = "Subgroup") +
#   scale_y_continuous(limits = c(0, 0.6), breaks = seq(0, 0.6, by = 0.1)) +
#   theme_Publication() +
#   scale_fill_Publication() +
#   theme(
#     legend.position = "bottom",
#     strip.text = element_text(size = rel(0.9)),
#     panel.spacing = unit(1, "lines"),
#     axis.text.x = element_text(angle = 0)
#   )
# 
# # Display the plot
# print(p)
# 
# # Save the plot
# ggsave("treatment_probability_risk_factors.png", p, width = 15, height = 4, dpi = 300)

# Alternative visualization with connected points to show progression.
# One facet per sequence group; within a facet the x axis is the step number
# (subgroup_num) and the y axis is the treatment probability.

# ggplot2 3.4.0 renamed the line-width argument of line geoms from `size` to
# `linewidth` and warns about the old spelling. Pick the spelling the installed
# version understands so the plot is warning-free on current ggplot2 and still
# drawn with the intended width on older installs.
p2_line_layer <- if (utils::packageVersion("ggplot2") >= "3.4.0") {
  geom_line(linewidth = 1.2)
} else {
  geom_line(size = 1.2)
}

p2 <- ggplot(treatment_data, aes(x = factor(subgroup_num), y = probability,
                                group = seq_group, color = seq_group)) +
  p2_line_layer +
  # Hollow markers (shape 21 with white fill) drawn on top of the line
  geom_point(size = 3, shape = 21, fill = "white") +
  facet_grid(. ~ seq_group, scales = "free_x", space = "free_x") +
  labs(x = "",
       y = "",
       color = "Sequence Group") +
  # Limits are sized for the A-signal probabilities (largest value 0.333);
  # values outside 0-0.4 would be dropped by the scale.
  scale_y_continuous(limits = c(0, 0.4), breaks = seq(0, 0.4, by = 0.1)) +
  theme_Publication() +
  theme(
    legend.position = "bottom",
    strip.text = element_text(size = rel(0.9)),
    panel.spacing = unit(1, "lines")
  )

# Display the alternative plot
print(p2)

# Save the alternative plot
# ggsave("treatment_probability_risk_progression.png", p2, width = 12, height = 8, dpi = 300)