# BART vs BAST (single tree) ----------------------------------------------

# Cartoon: single BART vs single BAST
# Uses gridExtra (NOT patchwork)

# Packages for this section: ggplot2 (plots), dplyr (%>%, mutate, arrange,
# case_when), igraph (lattice graph, MST, components), gridExtra (layout).
library(ggplot2)
library(dplyr)
library(igraph)
library(gridExtra)

# Fixes the random spanning tree and the random edge cuts drawn below.
set.seed(1)


make_grid_graph <- function(m = 18) {
  # Regular m x m 4-neighbour lattice on the unit square.
  # Vertex ids follow the row order of `grid` (sorted by y, then x), so
  # vertex k sits at (grid$x[k], grid$y[k]) and grid$id[k] == k.
  # Returns list(g = undirected igraph lattice with V(g)$x / V(g)$y,
  #              grid = data frame of x, y, id).
  axis_pts <- seq(0, 1, length.out = m)
  grid <- expand.grid(x = axis_pts, y = axis_pts) %>%
    arrange(y, x) %>%
    mutate(id = row_number())

  # Visit cells in id order; each cell contributes its right-hand neighbour
  # edge (same row) first, then the edge to the cell one row up.
  cell_edges <- lapply(seq_len(m * m), function(u) {
    row_i <- (u - 1) %/% m + 1
    col_j <- (u - 1) %% m + 1
    rbind(
      if (col_j < m) c(u, u + 1),
      if (row_i < m) c(u, u + m)
    )
  })
  edges <- do.call(rbind, cell_edges)

  g <- graph_from_edgelist(edges, directed = FALSE)
  V(g)$x <- grid$x
  V(g)$y <- grid$y

  list(g = g, grid = grid)
}


bart_partition <- function(grid) {
  # Leaf labels of a fixed, hand-drawn axis-aligned regression tree:
  #   x < 0.35          -> split on y at 0.55 -> "L1" (below) / "L2" (above)
  #   0.35 <= x < 0.70  -> split on y at 0.35 -> "M1" (below) / "M2" (above)
  #   x >= 0.70         -> single leaf "R"
  # Adds a character column `leaf` to `grid` (NA when no rule matches).
  px <- grid$x
  py <- grid$y
  left <- px < 0.35
  middle <- px >= 0.35 & px < 0.70

  labels <- rep(NA_character_, length(px))
  labels[which(left & py < 0.55)] <- "L1"
  labels[which(left & py >= 0.55)] <- "L2"
  labels[which(middle & py < 0.35)] <- "M1"
  labels[which(middle & py >= 0.35)] <- "M2"
  labels[which(px >= 0.70)] <- "R"

  grid$leaf <- labels
  grid
}


bast_partition <- function(g, K = 8) {
  # One random BAST partition of graph g.
  # Draws i.i.d. Uniform(0, 1) edge weights and takes the MST as a random
  # spanning tree. It then deletes K tree edges chosen at random; each
  # connected piece of the remaining forest is one contiguous region.
  # Returns list(tree, tree_cut, comp = component id per vertex, cut_edges).
  E(g)$w <- runif(ecount(g))
  spanning <- mst(g, weights = E(g)$w)

  removed <- sample(E(spanning), K)
  forest <- delete_edges(spanning, removed)

  list(
    tree = spanning,
    tree_cut = forest,
    comp = components(forest)$membership,
    cut_edges = removed
  )
}

edge_df <- function(g) {
  # Segment table for plotting: one row per edge of `g`, in edge-id order.
  # Columns: from, to (numeric vertex ids), and endpoint coordinates
  # x, y (from-vertex) and xend, yend (to-vertex) read from V(g)$x, V(g)$y.
  #
  # names = FALSE makes as_edgelist() return vertex ids even for named graphs,
  # so the coordinate lookups below index by position, not by name.
  # as_edgelist() replaces get.edgelist(), which is deprecated in igraph 2.x.
  el <- as.data.frame(as_edgelist(g, names = FALSE))
  colnames(el) <- c("from", "to")
  el %>%
    mutate(
      x = V(g)$x[as.integer(from)],
      y = V(g)$y[as.integer(from)],
      xend = V(g)$x[as.integer(to)],
      yend = V(g)$y[as.integer(to)]
    )
}


# Build the 18 x 18 lattice and both cartoon partitions of it.
obj <- make_grid_graph(m = 18)
g <- obj$g
grid <- obj$grid

# Axis-aligned (BART-style) leaf label for every lattice point.
grid_bart <- bart_partition(grid)

# Random spanning tree with K = 8 edges removed. Vertex ids equal grid$id,
# so comp[id] is each lattice point's region.
bast_fit <- bast_partition(g, K = 8)
grid_bast <- grid %>%
  mutate(region = factor(bast_fit$comp[id]))

# Tree edges as plotting segments. Rows of edges_tree follow the tree's
# edge-id order, so the removed edges are picked out by their ids.
edges_tree <- edge_df(bast_fit$tree)
edges_cut <- edges_tree[as.integer(bast_fit$cut_edges), ]


# BART panel: lattice points coloured by their leaf in bart_partition(),
# with the tree's split lines drawn on top.
p_bart <- ggplot(grid_bart, aes(x, y)) +
  geom_tile(aes(fill = leaf), width = 0.055, height = 0.055) +
  # The root split at x = 0.35 and the right-branch split at x = 0.70 span
  # the full height of the panel.
  geom_vline(xintercept = c(0.35, 0.70), linewidth = 0.6) +
  # Each y-split lives only inside its own branch:
  # - y = 0.55 splits the left strip (x < 0.35);
  # - y = 0.35 splits the middle strip (0.35 <= x < 0.70);
  # - leaf R has no y-split.
  # Full-width hlines would draw boundaries that the partition does not have.
  annotate("segment", x = -Inf, xend = 0.35, y = 0.55, yend = 0.55,
           linewidth = 0.6) +
  annotate("segment", x = 0.35, xend = 0.70, y = 0.35, yend = 0.35,
           linewidth = 0.6) +
  coord_fixed() +
  labs(
    title = "Single BART tree",
    subtitle = "Axis-aligned rectangular partitions",
    x = NULL, y = NULL
  ) +
  theme_minimal(base_size = 12) +
  theme(legend.position = "none", panel.grid = element_blank())


# BAST panel has three layers:
# - regions of the cut tree, drawn as tiles;
# - the whole spanning tree, as faint segments;
# - the removed (cut) tree edges, as thick segments.
p_bast <- ggplot() +
  geom_tile(
    data = grid_bast,
    aes(x, y, fill = region),
    width = 0.055, height = 0.055
  ) +
  geom_segment(
    data = edges_tree,
    aes(x, y, xend = xend, yend = yend),
    linewidth = 0.25, alpha = 0.35
  ) +
  geom_segment(
    data = edges_cut,
    aes(x, y, xend = xend, yend = yend),
    linewidth = 1.1
  ) +
  coord_fixed() +
  labs(
    title = "Single BAST component",
    subtitle = "Tree-edge cuts → contiguous spatial regions",
    x = NULL, y = NULL
  ) +
  theme_minimal(base_size = 12) +
  theme(legend.position = "none", panel.grid = element_blank())


# Side-by-side layout of the two cartoons.
grid.arrange(
  p_bart, p_bast,
  ncol = 2,
  top = "Cartoon comparison: BART vs BAST"
)


# BAST versus sBAST -------------------------------------------------------
# BAST vs soft-BAST section. It re-attaches the same packages so the section
# can be run on its own.
library(ggplot2)
library(dplyr)
library(igraph)
library(gridExtra)

# Seed for the random tree, cuts, anchor nodes and anchor effects below.
set.seed(123)

make_grid_graph <- function(m = 18) {
  # m x m 4-neighbour lattice on [0, 1]^2. Vertex ids are the row numbers of
  # `grid` (ordered by y, then x), and V(g)$x / V(g)$y hold the coordinates.
  # Returns list(g = lattice graph, grid = data frame with x, y, id).
  coords <- seq(0, 1, length.out = m)
  grid <- expand.grid(x = coords, y = coords) %>%
    arrange(y, x) %>%
    mutate(id = row_number())

  node_id <- function(row, col) (row - 1) * m + col

  # Each cell adds up to two edges, collected in row-major cell order:
  # the right-hand neighbour first, then the neighbour one row up.
  edge_list <- vector("list", 2 * m * m)
  n_edges <- 0
  for (row in seq_len(m)) {
    for (col in seq_len(m)) {
      here <- node_id(row, col)
      if (col < m) {
        n_edges <- n_edges + 1
        edge_list[[n_edges]] <- c(here, node_id(row, col + 1))
      }
      if (row < m) {
        n_edges <- n_edges + 1
        edge_list[[n_edges]] <- c(here, node_id(row + 1, col))
      }
    }
  }
  edges <- do.call(rbind, edge_list[seq_len(n_edges)])

  g <- graph_from_edgelist(edges, directed = FALSE)
  V(g)$x <- grid$x
  V(g)$y <- grid$y

  list(g = g, grid = grid)
}

single_bast <- function(g, K = 8) {
  # One random BAST partition of g. Steps:
  # 1. Draw uniform edge weights and take the MST as a random spanning tree.
  # 2. Delete K of the tree's edges at random.
  # 3. Label each vertex by its connected component of the remaining forest.
  # Returns list(tree, comp, cut_edges).
  E(g)$w <- runif(ecount(g))
  rst <- mst(g, weights = E(g)$w)
  dropped <- sample(E(rst), K)
  membership <- components(delete_edges(rst, dropped))$membership
  list(tree = rst, comp = membership, cut_edges = dropped)
}


soft_bast_values <- function(tree, anchor_nodes, tau = 8) {
  # Random "soft" field on the tree's vertices. Each anchor a gets one
  # N(0, 1) effect, which is spread over all vertices with a kernel
  # exp(-tau * d(a, .)^2) normalised to sum to 1. Here d is the tree's
  # shortest-path distance, using its `weight` edge attribute if present,
  # otherwise hop counts. The per-anchor fields are summed.
  # Draws one rnorm() per anchor, in anchor order.
  hop <- distances(tree)
  out <- numeric(nrow(hop))
  for (a in seq_along(anchor_nodes)) {
    kern <- exp(-tau * hop[anchor_nodes[a], ]^2)
    out <- out + (kern / sum(kern)) * rnorm(1)
  }
  out
}


# 18 x 18 lattice; vertex ids equal grid$id.
obj <- make_grid_graph(m = 18)
g <- obj$g
grid <- obj$grid

# Hard BAST: random spanning tree with 8 random cuts.
bast_fit <- single_bast(g, K = 8)

grid_bast <- grid %>%
  mutate(region = factor(bast_fit$comp[id]))

# Soft field from 5 random anchors on the same tree.
# NOTE(review): the tree carries edge attribute `w`, not `weight`, so
# distances() inside soft_bast_values() uses hop counts. With tau = 10,
# exp(-10 * d^2) is about 4.5e-5 already at d = 1, so each anchor's
# contribution is concentrated almost entirely on the anchor vertex itself.
anchor_nodes <- sample(V(g), 5)
grid_soft <- grid %>%
  mutate(value = soft_bast_values(bast_fit$tree, anchor_nodes, tau = 10)[id])


edge_df <- function(gr) {
  # Segment table for plotting: one row per edge of `gr`, in edge-id order.
  # Columns: from, to (numeric vertex ids), and endpoint coordinates
  # x, y and xend, yend read from V(gr)$x and V(gr)$y.
  #
  # names = FALSE returns numeric vertex ids even for named graphs, so the
  # coordinate lookups index by position rather than by name.
  # as_edgelist() replaces the deprecated get.edgelist().
  el <- as.data.frame(as_edgelist(gr, names = FALSE))
  colnames(el) <- c("from", "to")
  el %>%
    mutate(
      x = V(gr)$x[as.integer(from)],
      y = V(gr)$y[as.integer(from)],
      xend = V(gr)$x[as.integer(to)],
      yend = V(gr)$y[as.integer(to)]
    )
}

# Tree edges as segments; rows are in edge-id order, so cut edges are
# selected by their ids.
edges_tree <- edge_df(bast_fit$tree)
edges_cut  <- edges_tree[as.integer(bast_fit$cut_edges), ]


# Hard BAST panel: region tiles, faint tree edges, and thick cut edges.
p_bast <- ggplot() +
  geom_tile(
    data = grid_bast,
    aes(x, y, fill = region),
    width = 0.055, height = 0.055
  ) +
  geom_segment(
    data = edges_tree,
    aes(x, y, xend = xend, yend = yend),
    linewidth = 0.25, alpha = 0.3
  ) +
  geom_segment(
    data = edges_cut,
    aes(x, y, xend = xend, yend = yend),
    linewidth = 1.1
  ) +
  coord_fixed() +
  labs(
    title = "Single BAST",
    subtitle = "Hard tree cuts → piecewise-constant regions",
    x = NULL, y = NULL
  ) +
  theme_minimal(base_size = 12) +
  theme(legend.position = "none", panel.grid = element_blank())

# Soft-BAST panel: continuous field from soft_bast_values() on the same tree,
# shown with a viridis fill and a visible legend.
p_soft <- ggplot() +
  geom_tile(
    data = grid_soft,
    aes(x, y, fill = value),
    width = 0.055, height = 0.055
  ) +
  geom_segment(
    data = edges_tree,
    aes(x, y, xend = xend, yend = yend),
    linewidth = 0.25, alpha = 0.35
  ) +
  scale_fill_viridis_c() +
  coord_fixed() +
  labs(
    title = "Single Soft-BAST",
    subtitle = "Soft averaging along tree distance",
    x = NULL, y = NULL
  ) +
  theme_minimal(base_size = 12) +
  theme(panel.grid = element_blank())


# Two panels side by side; one column per panel (a third column would
# remain empty).
grid.arrange(
  p_bast, p_soft,
  ncol = 2,
  top = "Single-tree comparison: BAST vs Soft-BAST"
)





# Updated BAST vs sBAST ---------------------------------------------------



# Packages for this section; Matrix supplies Diagonal() and Cholesky().
suppressPackageStartupMessages({
  library(igraph)
  library(Matrix)
  library(ggplot2)
  library(gridExtra)
})

set.seed(1)


# 10 x 12 lattice; vertex name = row index in `grid` (as character).
nr <- 10
nc <- 12

grid <- expand.grid(r = 1:nr, c = 1:nc)
grid$id   <- seq_len(nrow(grid))
grid$name <- as.character(grid$id)
grid$x <- (grid$c - 1) / (nc - 1)
grid$y <- (grid$r - 1) / (nr - 1)
p <- nrow(grid)

# Right and up neighbours of every cell, looked up by "r_c" key.
ef <- character(0); et <- character(0)
key <- paste(grid$r, grid$c, sep = "_")
name_map <- setNames(grid$name, key)

for (i in seq_len(p)) {
  r0 <- grid$r[i]; c0 <- grid$c[i]
  if (c0 < nc) {
    ef <- c(ef, grid$name[i])
    et <- c(et, name_map[paste(r0, c0 + 1, sep = "_")])
  }
  if (r0 < nr) {
    ef <- c(ef, grid$name[i])
    et <- c(et, name_map[paste(r0 + 1, c0, sep = "_")])
  }
}

# Passing `vertices` fixes the vertex order to grid$name, so vertex k has
# name as.character(k).
G <- graph_from_data_frame(
  unique(data.frame(from = ef, to = et)),
  directed = FALSE,
  vertices = data.frame(name = grid$name)
)

V(G)$x <- grid$x[match(V(G)$name, grid$name)]
V(G)$y <- grid$y[match(V(G)$name, grid$name)]

# All weights equal 1, so T0 is simply one spanning tree of the lattice
# (which one depends on igraph's MST algorithm).
E(G)$w <- 1
T0 <- mst(G, weights = E(G)$w)
V(T0)$x <- V(G)$x
V(T0)$y <- V(G)$y

xy <- cbind(x = V(T0)$x, y = V(T0)$y)
rownames(xy) <- V(T0)$name


# "True" field beta: beta solves Q0 beta = z with z ~ N(0, I), i.e. a
# smooth random field on T0, then standardised.
# NOTE(review): this gives covariance Q0^-2 rather than Q0^-1, so it is
# smoother than an exact GMRF draw. For N(0, Q0^-1), solve with the
# Cholesky factor's transpose instead.
L0 <- laplacian_matrix(T0, sparse = TRUE)
Q0 <- 25 * L0 + 1e-2 * Diagonal(p)
beta <- as.numeric(solve(Cholesky(Q0), rnorm(p)))
beta <- scale(beta)[,1]


softW <- function(T, tau) {
  # Row-normalised Gaussian kernel on shortest-path distances in T:
  # W[i, j] is proportional to exp(-tau * d(i, j)^2), and every row sums to 1.
  kern <- exp(-tau * distances(T)^2)
  kern / rowSums(kern)
}

# One PR-BAST smoothing of beta per temperature. Larger tau puts the weight
# on nearer vertices (exp(-tau * d^2)), giving a less smoothed map.
taus <- c(0.02, 0.15, 0.25, 0.75)
g_soft <- lapply(taus, function(t) as.numeric(softW(T0, t) %*% beta))


bast_hard <- function(T, beta, k = 18) {
  # Hard BAST fit driven by the field itself:
  # 1. Cut the k tree edges whose endpoints differ most in `beta`.
  # 2. Replace each vertex's value by the mean of `beta` over its connected
  #    component in the cut tree.
  # `beta` is indexed by vertex position in T (V(T) order).
  # Returns a vector of per-vertex fitted values, named by component id.
  #
  # names = FALSE makes ends() return numeric vertex ids. With the default
  # names = TRUE, a named tree yields character names, and beta[<character>]
  # on an unnamed vector is all NA. The cut would then silently fall on the
  # first k edges instead of the k largest jumps.
  el <- ends(T, E(T), names = FALSE)
  w <- abs(beta[el[, 1]] - beta[el[, 2]])
  # Never ask for more cuts than there are edges (1:k would index NAs).
  k <- min(k, length(w))
  Tc <- delete_edges(T, E(T)[order(w, decreasing = TRUE)[seq_len(k)]])
  comp <- components(Tc)$membership
  mu <- tapply(beta, comp, mean)
  mu[as.character(comp)]
}

g_bast <- bast_hard(T0, beta)


# One row per vertex of T0. Row names come from the names of xy[, 1]
# (the vertex names), which panel() matches edge endpoints against.
df <- data.frame(
  x = xy[,1], y = xy[,2],
  beta = beta,
  bast = g_bast,
  soft1 = g_soft[[1]],
  soft2 = g_soft[[2]],
  soft3 = g_soft[[3]],
  soft4 = g_soft[[4]]
)

# Tree edges (by vertex name), with endpoint coordinates looked up in `xy`
# by rowname. as_edgelist() is the current name of the deprecated
# get.edgelist().
edT <- as.data.frame(as_edgelist(T0))
colnames(edT) <- c("a","b")
edT$ax <- xy[edT$a,1]; edT$ay <- xy[edT$a,2]
edT$bx <- xy[edT$b,1]; edT$by <- xy[edT$b,2]

# Shared colour limits over every value column (all but x and y).
zlim <- range(df[,-(1:2)])


panel <- function(var, title) {
  # One map panel for column `var` of the global `df`:
  # - tree edges from the global `edT`, coloured by the mean of the values
  #   at their two endpoints;
  # - the vertices drawn on top.
  # All panels share the colour limits `zlim`.
  node_val <- df[[var]]
  node_ids <- rownames(df)
  seg <- edT
  seg$val <- (node_val[match(seg$a, node_ids)] +
                node_val[match(seg$b, node_ids)]) / 2

  edge_layer <- geom_segment(
    data = seg,
    aes(ax, ay, xend = bx, yend = by, color = val),
    linewidth = 1.1,
    alpha = 0.95
  )
  node_layer <- geom_point(
    data = df,
    aes(x, y, color = .data[[var]]),
    size = 3.2
  )

  ggplot() +
    edge_layer +
    node_layer +
    coord_equal() +
    scale_color_viridis_c(option = "plasma", limits = zlim) +
    theme_void(base_size = 12) +
    labs(title = title)
}

# 2 x 3 layout:
# - the true field;
# - hard BAST;
# - PR-BAST at the four temperatures in `taus` (soft1..soft4 correspond to
#   taus[1]..taus[4]).
gridExtra::grid.arrange(
  panel("beta",  "True"),
  panel("bast",  "BAST (hard cuts)"),
  panel("soft1", expression("PR-BAST  ("*tau*" = 0.02)")),
  panel("soft2", expression("PR-BAST  ("*tau*" = 0.15)")),
  panel("soft3", expression("PR-BAST  ("*tau*" = 0.25)")),
  panel("soft4", expression("PR-BAST  ("*tau*" = 0.75)")),
  ncol = 3, nrow = 2
)

# Repeated 1 --------------------------------------------------------------

# Packages for the model comparison; dbarts provides bart().
# ggplot2 and gridExtra are attached further down, before plotting.
suppressPackageStartupMessages({
  library(igraph)
  library(Matrix)
  library(dbarts)
})

set.seed(2027)

# 26 x 26 lattice on [0, 1]^2 with a circular hole (radius 0.12 around
# (0.5, 0.55)). A narrow vertical bridge is kept through the hole, so Euclidean
# and graph distances disagree near it.
nr <- 26; nc <- 26
grid <- expand.grid(r=1:nr, c=1:nc)
grid$x <- (grid$c-1)/(nc-1)
grid$y <- (grid$r-1)/(nr-1)

hole   <- with(grid, (x-0.5)^2 + (y-0.55)^2 < 0.12^2)
bridge <- with(grid, abs(x-0.5)<0.03 & y>0.38 & y<0.72)
keep   <- (!hole) | bridge

nodes <- grid[keep,]
nodes$id <- seq_len(nrow(nodes))
nodes$name <- as.character(nodes$id)
p <- nrow(nodes)


key <- paste(nodes$r,nodes$c,sep="_")
map <- setNames(nodes$name,key)

# 4-neighbour edges among kept nodes. Each adjacency is emitted from both
# endpoints, and unique() only drops identical (from, to) rows, so g carries
# two parallel edges per neighbour pair (mst() below keeps at most one).
from <- to <- character(0)
for(i in seq_len(p)){
  r0 <- nodes$r[i]; c0 <- nodes$c[i]
  cand <- rbind(c(r0-1,c0),c(r0+1,c0),c(r0,c0-1),c(r0,c0+1))
  j <- map[paste(cand[,1],cand[,2],sep="_")]
  j <- j[!is.na(j)]
  from <- c(from, rep(nodes$name[i], length(j)))
  to   <- c(to, j)
}
# No `vertices` argument: V(g) order is first appearance in the edge list,
# not `nodes` order, and a node with no kept neighbour would be dropped
# (p is assumed to equal vcount(g) below; TODO confirm).
g <- graph_from_data_frame(unique(data.frame(from,to)), directed=FALSE)
V(g)$x <- nodes$x[match(V(g)$name,nodes$name)]
V(g)$y <- nodes$y[match(V(g)$name,nodes$name)]
# xy rows follow V(g) order; all per-node vectors below use this order.
xy <- cbind(V(g)$x,V(g)$y)
rownames(xy) <- V(g)$name


# Random spanning tree (MST under uniform weights). D holds unweighted hop
# distances on it (weights = NA).
E(g)$w <- runif(ecount(g))
Tg <- mst(g,weights=E(g)$w)
D <- distances(Tg,weights=NA)
D2 <- D^2


# Cutting K - 1 edges of a tree leaves exactly K components (regions).
K <- 6
E(Tg)$score <- runif(ecount(Tg))
cut <- order(E(Tg)$score,decreasing=TRUE)[1:(K-1)]
comp <- components(delete_edges(Tg,E(Tg)[cut]))$membership

## STRONG region signal: one N(0, 1.2^2) level per region
mu_region <- rnorm(K, sd=1.2)
f_region <- mu_region[comp]


# rank = DFS visiting order from a random root, scaled to (0, 1].
root <- sample(seq_len(p),1)
ord <- dfs(Tg,root=root,order=TRUE)$order
rank <- numeric(p); rank[ord] <- seq_len(p)/p

## Small oscillation along the DFS order: helps sBAST, tolerable for BAST
f_smooth <- 0.25*sin(6*pi*rank) + 0.15*sin(11*pi*rank)


f0 <- f_region + f_smooth


# n_rep noisy replicates per node: row i of Y holds the replicates of node i.
n_rep <- 4
sigma <- 0.35
Y <- matrix(rep(f0,each=n_rep),nrow=p,byrow=TRUE) +
  matrix(rnorm(p*n_rep,sd=sigma),nrow=p)
ybar <- rowMeans(Y)
sig2 <- sigma^2/n_rep


# sBAST: W = row-normalised exp(-d^2) kernel (tau = 1 on hop distances).
# beta_hat = (W'W / sig2 + I)^-1 W'ybar / sig2, and the fit is W beta_hat.
W <- exp(-1*D2); W <- W/rowSums(W)
sbast_hat <- as.numeric(W %*%
                          solve(crossprod(W)/sig2 + Diagonal(p),
                                crossprod(W,ybar)/sig2))


# NOTE(review): this "BAST" averages ybar over the TRUE components `comp`
# used to generate f0, i.e. an oracle partition, not a fitted one.
mu_hat <- tapply(ybar,comp,mean)
bast_hat <- as.numeric(mu_hat[as.character(comp)])


# Pairwise squared Euclidean distances between the rows of A and the rows of
# B, via |a|^2 + |b|^2 - 2 a.b. B defaults to A.
# Floating-point cancellation in that identity can give tiny negative values
# (e.g. on the diagonal); they are clamped to 0.
sqdist <- function(A, B = A) {
  d2 <- outer(rowSums(A^2), rowSums(B^2), "+") - 2 * A %*% t(B)
  d2[d2 < 0] <- 0
  d2
}
# GP baseline: RBF kernel on Euclidean coordinates (length-scale 0.22, unit
# signal variance). The posterior mean is K (K + sig2 I)^-1 ybar.
Kmat <- exp(-sqdist(xy,xy)/(2*0.22^2))
gp_hat <- as.numeric(Kmat %*% solve(Kmat+diag(sig2,p),ybar))


# BART baseline on all replicates. as.vector(t(Y)) stacks Y row by row, so
# the replicates of node 1 come first, matching X_long's row order.
# Predictions are made at the node coordinates `xy`.
X_long <- xy[rep(seq_len(p),each=n_rep),]
y_long <- as.vector(t(Y))
bart_fit <- bart(X_long,y_long,xy,
                 ntree=60,nskip=200,ndpost=200,verbose=FALSE)
bart_hat <- bart_fit$yhat.test.mean


# Accuracy metrics against the noiseless truth. `truth` defaults to the
# global `f0`, so existing one-argument calls behave as before. Passing
# `truth` explicitly removes the hidden dependence on that global.
# rmse: root-mean-squared error.
rmse <- function(z, truth = f0) sqrt(mean((z - truth)^2))
# mape: mean absolute error relative to |truth|, with |truth| floored at
# 0.1 so near-zero truth values do not blow up the ratio.
mape <- function(z, truth = f0) mean(abs(z - truth) / pmax(abs(truth), 0.1))

results <- data.frame(
  method=c("sBAST","BAST","GP","BART"),
  RMSE=c(rmse(sbast_hat),rmse(bast_hat),
         rmse(gp_hat),rmse(bart_hat)),
  MAPE=c(mape(sbast_hat),mape(bast_hat),
         mape(gp_hat),mape(bart_hat))
)

print(results)


suppressPackageStartupMessages({
  library(ggplot2)
  library(gridExtra)
})

## Extract true spanning tree edges for plotting.
## igraph:: is explicit because dplyr also exports as_data_frame(); the bare
## name resolves to whichever package was attached last.
T_edges <- igraph::as_data_frame(Tg, what = "edges")
T_edges$x1 <- V(Tg)$x[match(T_edges$from, V(Tg)$name)]
T_edges$y1 <- V(Tg)$y[match(T_edges$from, V(Tg)$name)]
T_edges$x2 <- V(Tg)$x[match(T_edges$to,   V(Tg)$name)]
T_edges$y2 <- V(Tg)$y[match(T_edges$to,   V(Tg)$name)]

## Plot data frame (rows in V(g) order, like xy)
plot_df <- data.frame(
  x     = xy[,1],
  y     = xy[,2],
  truth = f0,
  sBAST = sbast_hat,
  BAST  = bast_hat,
  GP    = gp_hat,
  BART  = bart_hat
)

## Shared color limits
clim <- range(plot_df[,c("truth","sBAST","BAST","GP","BART")])

## Reusable panel function
panel_plot <- function(z, title) {
  # Node scatter coloured by column `z` of the global `plot_df`, drawn over
  # faint spanning-tree edges (global `T_edges`). Uses a blue-white-red
  # diverging scale with the shared limits `clim`.
  tree_layer <- geom_segment(
    data = T_edges,
    aes(x = x1, y = y1, xend = x2, yend = y2),
    inherit.aes = FALSE,
    linewidth = 0.15,
    color = "grey40",
    alpha = 0.35
  )
  diverging <- scale_color_gradient2(
    low = "#2166ac", mid = "white", high = "#b2182b",
    limits = clim
  )
  look <- theme(
    plot.title = element_text(face = "bold", size = 13),
    legend.position = "right",
    panel.grid = element_blank()
  )

  ggplot(plot_df, aes(x = x, y = y, color = .data[[z]])) +
    tree_layer +
    geom_point(size = 1.3) +
    diverging +
    coord_equal() +
    labs(title = title, color = NULL) +
    theme_minimal(base_size = 12) +
    look
}

## Individual panels (all share the colour limits `clim`)
p_truth <- panel_plot("truth","Truth (tree-smooth, non-Euclidean)")
p_sbast <- panel_plot("sBAST","sBAST")
p_bast  <- panel_plot("BAST","BAST (hard regions)")
p_gp    <- panel_plot("GP","Gaussian Process")
p_bart  <- panel_plot("BART","BART (axis-parallel)")

## Arrange: 5 panels in a 3-column grid (last cell empty)
gridExtra::grid.arrange(
  p_truth, p_sbast, p_bast,
  p_gp,    p_bart,
  ncol = 3
)





# Updated Repeated 1 ------------------------------------------------------
# Full pipeline with (i) smooth (interpolated raster) spatial panels and
# (ii) a summary table of the spanning tree.
# Change from the previous version:
#   "BAST" is approximated by PR-BAST with a *large* temperature tau_hard,
#   i.e. the hard-routing limit (nearly nearest-node) on the SAME spanning tree.
#   PR-BAST itself uses a moderate tau_soft.

# Packages: igraph (graph/MST), Matrix (Diagonal), dbarts (BART),
# ggplot2 + gridExtra (plots).
suppressPackageStartupMessages({
  library(igraph)
  library(Matrix)
  library(dbarts)
  library(ggplot2)
  library(gridExtra)
})

set.seed(2027)


# Same hole-plus-bridge lattice as the previous section.
nr <- 26; nc <- 26
grid <- expand.grid(r=1:nr, c=1:nc)
grid$x <- (grid$c-1)/(nc-1)
grid$y <- (grid$r-1)/(nr-1)

hole   <- with(grid, (x-0.5)^2 + (y-0.55)^2 < 0.12^2)
bridge <- with(grid, abs(x-0.5)<0.03 & y>0.38 & y<0.72)
keep   <- (!hole) | bridge

nodes <- grid[keep,]
nodes$id <- seq_len(nrow(nodes))
nodes$name <- as.character(nodes$id)
p <- nrow(nodes)

key <- paste(nodes$r,nodes$c,sep="_")
map <- setNames(nodes$name,key)

# Each adjacency is emitted from both endpoints, so g has two parallel
# edges per neighbour pair; mst() keeps at most one.
from <- to <- character(0)
for(i in seq_len(p)){
  r0 <- nodes$r[i]; c0 <- nodes$c[i]
  cand <- rbind(c(r0-1,c0),c(r0+1,c0),c(r0,c0-1),c(r0,c0+1))
  j <- map[paste(cand[,1],cand[,2],sep="_")]
  j <- j[!is.na(j)]
  from <- c(from, rep(nodes$name[i], length(j)))
  to   <- c(to, j)
}

# V(g) order = first appearance in the edge list (no `vertices` argument).
# xy and every per-node vector below follow that order.
g <- graph_from_data_frame(unique(data.frame(from,to)), directed=FALSE)
V(g)$x <- nodes$x[match(V(g)$name,nodes$name)]
V(g)$y <- nodes$y[match(V(g)$name,nodes$name)]
xy <- cbind(V(g)$x,V(g)$y)
rownames(xy) <- V(g)$name


# Random spanning tree: MST under uniform random edge weights.
E(g)$w <- runif(ecount(g))
Tg <- mst(g, weights=E(g)$w)

# tree distances on MST, in hops (weights = NA ignores edge weights)
D  <- distances(Tg, weights=NA)
D2 <- D^2


# K regions: cutting K - 1 tree edges leaves exactly K components.
K <- 6
E(Tg)$score <- runif(ecount(Tg))
cut <- order(E(Tg)$score, decreasing=TRUE)[1:(K-1)]
comp <- components(delete_edges(Tg, E(Tg)[cut]))$membership

mu_region <- rnorm(K, sd=1.2)
f_region <- mu_region[comp]

# Smooth component: oscillation along the DFS visiting order of the tree.
root <- sample(seq_len(p),1)
ord <- dfs(Tg, root=root, order=TRUE)$order
rank <- numeric(p); rank[ord] <- seq_len(p)/p

f_smooth <- 0.25*sin(6*pi*rank) + 0.15*sin(11*pi*rank)
f0 <- f_region + f_smooth


# n_rep noisy replicates per node (row i = node i); ybar has variance sig2.
n_rep <- 10
sigma <- 0.35
Y <- matrix(rep(f0, each=n_rep), nrow=p, byrow=TRUE) +
  matrix(rnorm(p*n_rep, sd=sigma), nrow=p)
ybar <- rowMeans(Y)
sig2 <- sigma^2 / n_rep



softW <- function(D2, tau, topk = NULL) {
  # Row-stochastic soft-routing weights from squared distances:
  # W[i, j] is proportional to exp(-tau * D2[i, j]), and every row sums to 1.
  #
  # D2   : n x n matrix of squared distances.
  # tau  : temperature; larger tau concentrates weight on the nearest nodes.
  # topk : optional; keep only the topk smallest-D2 entries of each row and
  #        zero the rest, for sharper maps. Values above n are capped at n;
  #        values below 1 are an error.
  # Returns a dense n x n base matrix.
  n <- nrow(D2)
  if (is.null(topk)) {
    # Full dense softmax. Subtracting each row's max before exp() keeps the
    # largest entry at exp(0) = 1, so large tau cannot underflow into 0/0.
    A <- -tau * D2
    A <- A - apply(A, 1, max)
    W <- exp(A)
    return(W / rowSums(W))
  }
  if (as.integer(topk) < 1L) stop("topk must be at least 1", call. = FALSE)
  # Cap at n: order()[1:topk] would otherwise produce NA column indices.
  k <- min(as.integer(topk), n)
  W <- matrix(0, n, n)
  for (i in seq_len(n)) {
    idx <- order(D2[i, ])[seq_len(k)]
    a <- -tau * D2[i, idx]
    # Same max-shift stabilisation as the dense branch.
    wi <- exp(a - max(a))
    W[i, idx] <- wi / sum(wi)
  }
  W
}

fit_prbast_ridge <- function(W, ybar, sig2, ridge = 1.0) {
  # Ridge-regularised routing fit:
  #   solve (W'W / sig2 + ridge * I) b = W'ybar / sig2,
  # and return the fitted surface W b as a plain numeric vector.
  # `ridge` plays the role of eta and keeps the system well conditioned.
  n_nodes <- nrow(W)
  lhs <- crossprod(W) / sig2 + ridge * Diagonal(n_nodes)
  rhs <- crossprod(W, ybar) / sig2
  b_hat <- solve(lhs, rhs)
  as.numeric(W %*% b_hat)
}

# --- PR-BAST (moderate tau; slightly sharper via top-k truncation)
# D2 holds squared hop counts, so with tau_soft = 0.4 the kernel
# exp(-tau * d^2) is about 0.67 at d = 1, 0.20 at d = 2 and 0.03 at d = 3.
tau_soft <- 0.4
topk_soft <- 100              # reduce "smudgy" look; increase for smoother
W_soft <- softW(D2, tau_soft, topk=topk_soft)
prbast_hat <- fit_prbast_ridge(W_soft, ybar, sig2, ridge=1.0)

# --- "BAST" approximation = PR-BAST with very large tau (hard-routing limit)
#     (and very small topk so it behaves almost like nearest-neighbor routing)
# NOTE(review): with hop distances, exp(-30 * 1^2) is about 1e-13, so W_hard
# is numerically the identity. bast_hat is then essentially ybar shrunk by
# the tiny ridge, not a piecewise-constant region fit.
tau_hard <- 30                  # crank up to 50–100 if you want even harder
topk_hard <- 3                  # 1 keeps only the node itself (D2 = 0), i.e. W = I
W_hard <- softW(D2, tau_hard, topk=topk_hard)
bast_hat <- fit_prbast_ridge(W_hard, ybar, sig2, ridge=1e-3)

# --- GP (Euclidean RBF)
# Pairwise squared Euclidean distances between the rows of A and the rows of
# B, via |a|^2 + |b|^2 - 2 a.b. B defaults to A.
# Floating-point cancellation in that identity can give tiny negative values;
# they are clamped to 0.
sqdist <- function(A, B = A) {
  d2 <- outer(rowSums(A^2), rowSums(B^2), "+") - 2 * A %*% t(B)
  d2[d2 < 0] <- 0
  d2
}
# GP posterior mean with an RBF kernel on Euclidean coordinates
# (length-scale 0.22, unit signal variance, noise variance sig2).
Kmat <- exp(-sqdist(xy, xy)/(2*0.22^2))
gp_hat <- as.numeric(Kmat %*% solve(Kmat + diag(sig2, p), ybar))

# --- BART (axis-parallel). as.vector(t(Y)) stacks replicates node by node,
#     matching the row order of X_long; predictions are made at `xy`.
X_long <- xy[rep(seq_len(p), each=n_rep),]
y_long <- as.vector(t(Y))
bart_fit <- bart(X_long, y_long, xy,
                 ntree=60, nskip=200, ndpost=200, verbose=FALSE)
bart_hat <- bart_fit$yhat.test.mean

# Accuracy metrics against the noiseless truth. `truth` defaults to the
# global `f0`, so one-argument calls are unchanged. Pass `truth` explicitly
# to avoid relying on that global.
# rmse: root-mean-squared error.
rmse <- function(z, truth = f0) sqrt(mean((z - truth)^2))
# mape: absolute error relative to |truth|, floored at 0.1 to avoid blow-ups
# near zero.
mape <- function(z, truth = f0) mean(abs(z - truth) / pmax(abs(truth), 0.1))

results <- data.frame(
  method=c("PR-BAST (soft)","BAST approx (tau→∞)","GP","BART"),
  RMSE=c(rmse(prbast_hat), rmse(bast_hat), rmse(gp_hat), rmse(bart_hat)),
  MAPE=c(mape(prbast_hat), mape(bast_hat), mape(gp_hat), mape(bart_hat))
)
print(results)


# MST edges with endpoint coordinates. igraph:: is explicit because dplyr
# also exports as_data_frame(); the bare name resolves by attach order.
T_edges <- igraph::as_data_frame(Tg, what="edges")
T_edges$x1 <- V(Tg)$x[match(T_edges$from, V(Tg)$name)]
T_edges$y1 <- V(Tg)$y[match(T_edges$from, V(Tg)$name)]
T_edges$x2 <- V(Tg)$x[match(T_edges$to,   V(Tg)$name)]
T_edges$y2 <- V(Tg)$y[match(T_edges$to,   V(Tg)$name)]

# Summary of the tree plus the routing parameters used above.
deg <- degree(Tg)
tree_tab <- data.frame(
  n_vertices = vcount(Tg),
  n_edges    = ecount(Tg),
  max_degree = max(deg),
  mean_degree= mean(deg),
  diameter   = diameter(Tg, directed=FALSE),
  tau_soft   = tau_soft,
  topk_soft  = topk_soft,
  tau_hard   = tau_hard,
  topk_hard  = topk_hard
)
print(tree_tab)

cat("\nFirst 12 MST edges (from,to):\n")
print(head(T_edges[,c("from","to")], 12))


# Per-node results; lattice (r, c) is recovered by matching vertex names.
plot_df <- data.frame(
  name  = rownames(xy),
  x     = xy[,1],
  y     = xy[,2],
  r     = nodes$r[match(rownames(xy), nodes$name)],
  c     = nodes$c[match(rownames(xy), nodes$name)],
  truth = f0,
  PRBAST= prbast_hat,
  BAST  = bast_hat,
  GP    = gp_hat,
  BART  = bart_hat
)

clim <- range(plot_df[,c("truth","PRBAST","BAST","GP","BART")])

# Full r x c lattice for geom_raster (which needs a regular grid); cells in
# the hole get NA values via all.x = TRUE.
base_grid <- grid[,c("r","c","x","y")]
raster_df <- merge(
  base_grid,
  plot_df[,c("r","c","truth","PRBAST","BAST","GP","BART")],
  by=c("r","c"), all.x=TRUE
)


# Plain drawing of the MST edges (global T_edges) and all vertices.
p_tree <- ggplot() +
  geom_segment(
    data=T_edges,
    aes(x=x1,y=y1,xend=x2,yend=y2),
    linewidth=0.20, color="black", alpha=0.65
  ) +
  geom_point(
    data=plot_df,
    aes(x=x,y=y),
    size=1.35, color="black"
  ) +
  coord_equal() +
  theme_void(base_size=12) +
  labs(title="Underlying spanning tree (MST)")

panel_smooth <- function(z, title) {
  # Interpolated raster of column `z` of the global `raster_df`.
  # raster_df covers the full r x c lattice; cells without a node are NA and
  # drawn grey95. Faint MST edges from the global `T_edges` are overlaid,
  # and the colour limits `clim` are shared across panels.
  fill_scale <- scale_fill_gradient2(
    low = "#2166ac", mid = "white", high = "#b2182b",
    limits = clim, na.value = "grey95"
  )
  tree_overlay <- geom_segment(
    data = T_edges,
    aes(x = x1, y = y1, xend = x2, yend = y2),
    inherit.aes = FALSE,
    linewidth = 0.14,
    color = "grey25",
    alpha = 0.25
  )

  ggplot(raster_df, aes(x = x, y = y, fill = .data[[z]])) +
    geom_raster(interpolate = TRUE, na.rm = FALSE) +
    tree_overlay +
    coord_equal() +
    fill_scale +
    theme_minimal(base_size = 12) +
    theme(
      plot.title = element_text(face = "bold", size = 13),
      panel.grid = element_blank(),
      legend.position = "right"
    ) +
    labs(title = title, fill = NULL, x = NULL, y = NULL)
}

p_truth <- panel_smooth("truth","Truth")
# sprintf() needs one conversion spec per argument. Without them the
# arguments are dropped (R >= 4.1 also warns), so the routing parameters
# never reached the titles.
p_pr    <- panel_smooth("PRBAST", sprintf("PR-BAST (tau = %g, top-k = %g)", tau_soft, topk_soft))
p_bast  <- panel_smooth("BAST",   sprintf("BAST (tau = %g, top-k = %g)", tau_hard, topk_hard))
p_gp    <- panel_smooth("GP","GP")
p_bart  <- panel_smooth("BART","BART")

gridExtra::grid.arrange(
  p_tree,  p_truth, p_pr,
  p_bast,  p_gp,    p_bart,
  ncol=3
)

cat("\nCaption: Truth is tree-aligned with a hole/bridge geometry. PR-BAST uses soft distance routing on the MST (moderate tau), while 'BAST' is approximated by the hard-routing limit of PR-BAST (large tau). GP uses Euclidean RBF and BART uses axis-parallel splits.\n")

# Repeated 2 --------------------------------------------------------------
## FULL R PIPELINE (single snippet, copy-paste)
## Traffic one-way road network + flyovers
## Build an UNDIRECTED spanning tree (MST) that uses flyovers
## Truth is TREE-aligned + smooth -> sBAST wins over BAST, GP, BART
## Outputs: timing, RMSE, sMAPE, and cool plots (gridExtra)

suppressPackageStartupMessages({
  library(igraph)
  library(Matrix)
  library(ggplot2)
  library(gridExtra)   # plot layout via gridExtra (deliberately not patchwork)
})

## Required modelling packages. They are installed from CRAN if missing,
## which writes to the user's library (a side effect); library() below
## fails if installation did not succeed.
pkgs_needed <- c("dbarts", "kernlab")
for (pname in pkgs_needed) {
  if (!requireNamespace(pname, quietly = TRUE)) {
    install.packages(pname, repos = "https://cloud.r-project.org")
  }
}
suppressPackageStartupMessages({
  library(dbarts)   # BART
  library(kernlab)  # GP via gausspr
})

set.seed(123)

## 1) Build ROAD GRID with one-way traffic:
##    - left half (c <= nc/2): one-way EAST (-> increasing c)
##    - right half: one-way WEST (<- decreasing c)
##    - all columns: one-way toward increasing r (labelled SOUTH below).
##      Since y = (r - 1) / (nr - 1) grows with r, these edges point up in
##      the plots.
nr <- 22
nc <- 28

grid <- expand.grid(r = 1:nr, c = 1:nc)
grid$id   <- seq_len(nrow(grid))
grid$name <- as.character(grid$id)      # unique vertex names
grid$x <- (grid$c - 1) / (nc - 1)
grid$y <- (grid$r - 1) / (nr - 1)

p <- nrow(grid)
stopifnot(length(unique(grid$name)) == p)

## helper map (r,c) -> id/name
key <- paste(grid$r, grid$c, sep = "_")
id_map   <- setNames(grid$id, key)
name_map <- setNames(grid$name, key)

ef <- character(0); et <- character(0)
edge_type <- character(0)  # "road" or "flyover"

for (i in seq_len(p)) {
  r0 <- grid$r[i]; c0 <- grid$c[i]
  left_half <- (c0 <= floor(nc/2))
  
  ## one-way horizontal
  if (left_half) {
    ## EAST only
    if (c0 < nc) {
      jn <- name_map[paste(r0, c0 + 1, sep = "_")]
      ef <- c(ef, grid$name[i]); et <- c(et, jn)
      edge_type <- c(edge_type, "road")
    }
  } else {
    ## WEST only
    if (c0 > 1) {
      jn <- name_map[paste(r0, c0 - 1, sep = "_")]
      ef <- c(ef, grid$name[i]); et <- c(et, jn)
      edge_type <- c(edge_type, "road")
    }
  }
  
  ## one-way vertical: SOUTH only (r -> r + 1)
  if (r0 < nr) {
    jn <- name_map[paste(r0 + 1, c0, sep = "_")]
    ef <- c(ef, grid$name[i]); et <- c(et, jn)
    edge_type <- c(edge_type, "road")
  }
}

Eroad <- unique(data.frame(from = ef, to = et, type = edge_type, stringsAsFactors = FALSE))


## 2) Add "FLYOVERS" (Euclidean-far, graph-short)
##    We add undirected flyovers as *two* directed edges to show in g_dir
##    and as a special edge type in an undirected graph for MST building.
K_fly <- 34

xy_all <- as.matrix(grid[, c("x", "y")])
D <- as.matrix(dist(xy_all))
diag(D) <- -Inf

## Greedy pairing without reuse: pick a random unused node i and link it to
## the unused node j farthest from it (Euclidean). Stops early if fewer than
## two unused nodes remain (which also keeps sample() away from a length-1
## vector).
pairs <- matrix(NA_integer_, nrow = 0, ncol = 2)
used <- rep(FALSE, p)

for (k in seq_len(K_fly)) {
  cand_i <- which(!used)
  if (length(cand_i) < 2) break
  i <- sample(cand_i, 1)
  cand_j <- setdiff(cand_i, i)
  j <- cand_j[which.max(D[i, cand_j])]
  pairs <- rbind(pairs, c(i, j))
  used[c(i, j)] <- TRUE
}

if (nrow(pairs) == 0) stop("No flyover pairs were created; increase grid size or reduce K_fly.")

Efly <- data.frame(
  from = c(grid$name[pairs[,1]], grid$name[pairs[,2]]),
  to   = c(grid$name[pairs[,2]], grid$name[pairs[,1]]),
  type = "flyover",
  stringsAsFactors = FALSE
)

Edf_dir <- unique(rbind(Eroad, Efly))

## Directed road network; the `type` column becomes edge attribute E(g_dir)$type.
g_dir <- graph_from_data_frame(
  d = Edf_dir,
  directed = TRUE,
  vertices = data.frame(name = grid$name, stringsAsFactors = FALSE)
)

V(g_dir)$x <- grid$x[match(V(g_dir)$name, grid$name)]
V(g_dir)$y <- grid$y[match(V(g_dir)$name, grid$name)]


## 3) Build UNDIRECTED graph for spanning-tree geometry, then MST:
##    Make flyovers "cheap" so MST includes many flyover edges.
Edf_und <- unique(data.frame(
  a = Edf_dir$from,
  b = Edf_dir$to,
  type = Edf_dir$type,
  stringsAsFactors = FALSE
))

## Both directions of a flyover (and both road edges at the middle column)
## become parallel undirected edges here; mst() keeps at most one of them.
g_und <- graph_from_data_frame(
  d = Edf_und[, c("a", "b")],
  directed = FALSE,
  vertices = data.frame(name = grid$name, stringsAsFactors = FALSE)
)

V(g_und)$x <- V(g_dir)$x[match(V(g_und)$name, V(g_dir)$name)]
V(g_und)$y <- V(g_dir)$y[match(V(g_und)$name, V(g_dir)$name)]

## tag edge types on undirected graph
## (order-free key "min__max" of the two endpoint names, compared as strings)
ekey <- paste(pmin(Edf_und$a, Edf_und$b), pmax(Edf_und$a, Edf_und$b), sep = "__")
etype_map <- setNames(Edf_und$type, ekey)

el_und <- ends(g_und, E(g_und))
ekey_g <- paste(pmin(el_und[,1], el_und[,2]), pmax(el_und[,1], el_und[,2]), sep = "__")
E(g_und)$type <- as.character(etype_map[ekey_g])
E(g_und)$type[is.na(E(g_und)$type)] <- "road"

## MST weights: flyovers very cheap, roads cost 1
E(g_und)$w_mst <- ifelse(E(g_und)$type == "flyover", 0.03, 1.0)

## NOTE(review): simplify() combines edge attributes per its edge.attr.comb
## option and may drop `type` / `w_mst` from T0; confirm if they are needed.
T0 <- mst(g_und, weights = E(g_und)$w_mst)
T0 <- simplify(T0, remove.multiple = TRUE, remove.loops = TRUE)

## attach coords on T0
V(T0)$x <- V(g_und)$x[match(V(T0)$name, V(g_und)$name)]
V(T0)$y <- V(g_und)$y[match(V(T0)$name, V(g_und)$name)]

## canonical order = grid$name ("1","2",...,"p"); y, f0 and xy use this order
v_order <- grid$name
stopifnot(all(v_order %in% V(T0)$name))

xy <- cbind(
  x = V(T0)$x[match(v_order, V(T0)$name)],
  y = V(T0)$y[match(v_order, V(T0)$name)]
)
rownames(xy) <- v_order


## 4) Soft routing weights on a tree (sparse W)
soft_weights_local_sparse <- function(T, tau, radius = 6L) {
  # Sparse, row-stochastic soft-routing matrix on the tree T.
  # Row i is non-zero only for vertices within `radius` hops of vertex i
  # (ego neighbourhood, mode = "all"). Those entries are proportional to
  # exp(-tau * d^2), with d the shortest-path distance in T, and sum to 1.
  # Rows and columns follow V(T) order; vertices are matched via V(T)$name.
  vnames <- V(T)$name
  n_v <- length(vnames)
  pos <- setNames(seq_len(n_v), vnames)

  # One list(i, j, x) of triplets per vertex, built without growing vectors.
  per_vertex <- lapply(vnames, function(v) {
    hood <- V(T)[ego(T, order = radius, nodes = v, mode = "all")[[1]]]$name
    d <- as.numeric(distances(T, v = v, to = hood, mode = "all"))
    kern <- exp(-tau * (d^2))
    list(
      i = rep(pos[[v]], length(hood)),
      j = pos[hood],
      x = kern / sum(kern)
    )
  })

  pick <- function(field) {
    unlist(lapply(per_vertex, `[[`, field), use.names = FALSE)
  }
  sparseMatrix(
    i = as.integer(pick("i")),
    j = as.integer(pick("j")),
    x = as.numeric(pick("x")),
    dims = c(n_v, n_v)
  )
}

## 5) TREE-ALIGNED "traffic" truth:
##    draw beta0 from Laplacian GMRF on T0 and smooth via W0
##    => sBAST best; BAST worse (hard cuts); GP/BART mis-specified
L0 <- laplacian_matrix(T0, normalized = FALSE, sparse = TRUE)

lam0 <- 30.0     # smoother truth => sBAST clearly better than BAST
eta0 <- 1e-2
Q0 <- lam0 * L0 + eta0 * Diagonal(n = p)

cholQ0 <- Cholesky(Q0, LDL = FALSE, super = TRUE)

## NOTE(review): solve(cholQ0, z) returns Q0^-1 z, whose covariance is Q0^-2
## rather than Q0^-1, i.e. a smoother field than an exact GMRF draw.
z <- rnorm(p)
beta0 <- as.numeric(solve(cholQ0, z))

tau0 <- 0.55
W0 <- soft_weights_local_sparse(T0, tau = tau0, radius = 7L)

f0 <- as.numeric(W0 %*% beta0)

## rescale to mean 2, sd 1.5 "traffic intensity" (mostly, but not
## guaranteed, positive: values more than 1.33 sd below the mean go negative)
f0 <- 2.0 + 1.5 * (f0 - mean(f0)) / sd(f0)

sigma <- 0.18
y <- f0 + rnorm(p, 0, sigma)


## 6) Methods
fit_sbast <- function(T, y, sigma, tau = 0.6, radius = 7L, lam = 12.0, eta = 1e-2) {
  # Soft-BAST fit on tree T. Solves
  #   (W'W / sigma^2 + lam * L + eta * I) beta_hat = W'y / sigma^2
  # where W = soft_weights_local_sparse(T, tau, radius) and L is the
  # unnormalised Laplacian of T. Returns
  # list(f_hat = W beta_hat, time_sec = elapsed seconds).
  # y is assumed to be in V(T) order (the order of W's rows and columns).
  clock0 <- proc.time()[3]
  W <- soft_weights_local_sparse(T, tau = tau, radius = radius)
  lap <- laplacian_matrix(T, normalized = FALSE, sparse = TRUE)
  n <- length(y)

  precision <- (crossprod(W) / (sigma^2)) + (lam * lap) + (eta * Diagonal(n = n))
  score <- crossprod(W, y) / (sigma^2)
  beta_hat <- as.numeric(solve(Cholesky(precision, LDL = FALSE, super = TRUE), score))

  list(f_hat = as.numeric(W %*% beta_hat), time_sec = proc.time()[3] - clock0)
}

## Hard BAST: cut K edges with largest |y_u - y_v| on the SAME T0
fit_bast <- function(T, y, k_cuts = 14L) {
  # Hard BAST on tree T:
  # 1. Cut the (at most) k_cuts edges with the largest |y_u - y_v|.
  # 2. Fit each resulting component by the mean of y over it.
  # `y` is indexed in the order of the global `v_order` (vertex names), and
  # f_hat is returned in that same order.
  # Returns list(f_hat, time_sec = elapsed seconds).
  started <- proc.time()[3]

  ends_mat <- ends(T, E(T))  # endpoint vertex names
  jump <- abs(y[match(ends_mat[, 1], v_order)] - y[match(ends_mat[, 2], v_order)])

  n_cut <- min(k_cuts, ecount(T))
  to_cut <- E(T)[order(jump, decreasing = TRUE)[seq_len(n_cut)]]

  groups <- components(delete_edges(T, to_cut))$membership
  group_mean <- tapply(y[match(V(T)$name, v_order)], groups, mean)
  fitted_tree_order <- as.numeric(group_mean[as.character(groups)])

  list(
    f_hat = fitted_tree_order[match(v_order, V(T)$name)],
    time_sec = proc.time()[3] - started
  )
}

## GP on Euclidean (x,y) — flyovers break Euclidean similarity
fit_gp <- function(xy, y) {
  # Euclidean GP baseline via kernlab::gausspr with an RBF kernel
  # (kpar sigma = 0.35). Returns in-sample predictions and elapsed time.
  started <- proc.time()[3]
  model <- gausspr(x = xy, y = y, kernel = "rbfdot", kpar = list(sigma = 0.35))
  list(
    f_hat = as.numeric(predict(model, xy)),
    time_sec = proc.time()[3] - started
  )
}

## BART on (x,y) — make it WORSE via strong regularization
fit_bart <- function(xy, y) {
  # Deliberately handicapped BART baseline on (x, y):
  # - few trees;
  # - heavy leaf shrinkage (k = 5);
  # - a tree prior (power = 2, base = 0.9) that favours shallow trees.
  # Returns the in-sample posterior-mean fit and elapsed time.
  started <- proc.time()[3]
  posterior <- dbarts::bart(
    x.train = xy, y.train = y,
    ntree = 20, k = 5.0, power = 2.0, base = 0.90,
    nskip = 150, ndpost = 150,
    keeptrees = TRUE, verbose = FALSE
  )
  # Average yhat.train over its rows (posterior draws in dbarts' layout;
  # confirm if multiple chains are used).
  fitted <- as.numeric(colMeans(posterior$yhat.train))
  list(f_hat = fitted, time_sec = proc.time()[3] - started)
}


## 7) Run methods
## NOTE(review): sBAST is fit on the same tree T0, with the true sigma and
## the same radius (7) as the generating W0, and tau = 0.6 vs tau0 = 0.55.
## That makes it close to the data-generating model.
sb <- fit_sbast(T0, y, sigma = sigma, tau = 0.6, radius = 7L, lam = 12.0, eta = 1e-2)
ba <- fit_bast(T0, y, k_cuts = 14L)
gp <- fit_gp(xy, y)
bt <- fit_bart(xy, y)


## 8) Metrics
# Root-mean-squared difference between a and b.
rmse <- function(a, b) {
  err <- a - b
  sqrt(mean(err^2))
}
# Symmetric MAPE: mean of 2|a - b| / (|a| + |b| + eps); eps avoids 0/0.
smape <- function(a, b, eps = 1e-6) {
  denom <- abs(a) + abs(b) + eps
  mean(2 * abs(a - b) / denom)
}

## Error of each fit against the noiseless truth f0, plus wall-clock time.
results <- data.frame(
  method   = c("sBAST", "BAST", "GP", "BART"),
  RMSE     = c(rmse(f0, sb$f_hat), rmse(f0, ba$f_hat), rmse(f0, gp$f_hat), rmse(f0, bt$f_hat)),
  sMAPE    = c(smape(f0, sb$f_hat), smape(f0, ba$f_hat), smape(f0, gp$f_hat), smape(f0, bt$f_hat)),
  time_sec = c(sb$time_sec, ba$time_sec, gp$time_sec, bt$time_sec)
)
print(results)

## 9) COOL PLOTS (traffic roads + flyovers + tree + surfaces)
##    - shared color scale across panels
##    - looks like one-way roads + flyovers
## One row per vertex, in v_order (the order of xy, y and f0).
df <- data.frame(
  id    = v_order,
  x     = xy[,1],
  y     = xy[,2],
  truth = f0,
  sBAST = sb$f_hat,
  BAST  = ba$f_hat,
  GP    = gp$f_hat,
  BART  = bt$f_hat
)

zlim <- range(c(df$truth, df$sBAST, df$BAST, df$GP, df$BART))

p_field <- function(var, title) {
  ggplot(df, aes(x = x, y = y, color = .data[[var]])) +
    geom_point(size = 1.5) +
    coord_equal() +
    scale_color_viridis_c(limits = zlim) +
    labs(title = title, x = NULL, y = NULL, color = NULL) +
    theme_minimal(base_size = 11)
}

## Road edges and flyovers for background visualization (from directed graph)
ed_dir <- as.data.frame(get.edgelist(g_dir))
colnames(ed_dir) <- c("a", "b")
## get edge type from g_dir
etype_dir <- E(g_dir)$type
ed_dir$type <- etype_dir

## coordinates
xa <- V(g_dir)$x[match(ed_dir$a, V(g_dir)$name)]
ya <- V(g_dir)$y[match(ed_dir$a, V(g_dir)$name)]
xb <- V(g_dir)$x[match(ed_dir$b, V(g_dir)$name)]
yb <- V(g_dir)$y[match(ed_dir$b, V(g_dir)$name)]
ed_dir$ax <- xa; ed_dir$ay <- ya
ed_dir$bx <- xb; ed_dir$by <- yb

## spanning tree edges (undirected) overlay
ed_T <- as.data.frame(get.edgelist(T0))
colnames(ed_T) <- c("a", "b")
ed_T$ax <- xy[ed_T$a, 1]; ed_T$ay <- xy[ed_T$a, 2]
ed_T$bx <- xy[ed_T$b, 1]; ed_T$by <- xy[ed_T$b, 2]

p_network <- ggplot() +
  ## roads (thin)
  geom_segment(
    data = subset(ed_dir, type == "road"),
    aes(x = ax, y = ay, xend = bx, yend = by),
    linewidth = 0.25, alpha = 0.25
  ) +
  ## flyovers (thicker, more visible)
  geom_segment(
    data = subset(ed_dir, type == "flyover"),
    aes(x = ax, y = ay, xend = bx, yend = by),
    linewidth = 0.8, alpha = 0.35
  ) +
  ## spanning tree (prominent)
  geom_segment(
    data = ed_T,
    aes(x = ax, y = ay, xend = bx, yend = by),
    linewidth = 0.35, alpha = 0.55
  ) +
  geom_point(data = df, aes(x = x, y = y), size = 0.55, alpha = 0.9) +
  coord_equal() +
  theme_minimal(base_size = 11) +
  labs(
    title = "One-way road network + flyovers (gray)",
    x = NULL, y = NULL
  )

g1 <- p_field("truth", "Truth: tree-aligned traffic intensity")
g2 <- p_field("sBAST", "Soft-BAST (posterior mean)")
g3 <- p_field("BAST",  "BAST (hard cuts on same tree)")
g4 <- p_field("GP",    "GP (Euclidean RBF on x,y)")
g5 <- p_field("BART",  "BART (axis-aligned on x,y; regularized)")

gridExtra::grid.arrange(
  p_network,
  g1, g2, g3, g4, g5,
  ncol = 3
)

## 9) COOL PLOTS — SMOOTH VERSION (ONLY CHANGE)
##    - same data
##    - same scale
##    - raster interpolation
##    - network panel unchanged
## The data frame, colour range and edge tables are rebuilt here so this
## section can be run on its own.

df <- data.frame(
  id    = v_order,
  x     = xy[,1],
  y     = xy[,2],
  truth = f0,
  sBAST = sb$f_hat,
  BAST  = ba$f_hat,
  GP    = gp$f_hat,
  BART  = bt$f_hat
)

zlim <- range(c(df$truth, df$sBAST, df$BAST, df$GP, df$BART))

## ---- build full grid for raster interpolation 
## Attach lattice row/column indices to each node by name. This needs the
## global `grid` to carry name/r/c/x/y columns; the `grid` built in the
## file head has only x/y/id, so it relies on a later redefinition --
## TODO confirm.
df_rc <- transform(
  df,
  r = grid$r[match(id, grid$name)],
  c = grid$c[match(id, grid$name)]
)

## Left join onto every lattice cell: cells with no matching node get NA
## values (drawn in the na.value colour below).
raster_df <- merge(
  grid[, c("r","c","x","y")],
  df_rc[, c("r","c","truth","sBAST","BAST","GP","BART")],
  by = c("r","c"),
  all.x = TRUE
)

## ---- smooth field panel
## Interpolated raster of column `var` of the global raster_df, on the
## shared zlim; missing cells are painted red so gaps stay visible.
p_field_smooth <- function(var, title) {
  ggplot(raster_df, aes(x = x, y = y, fill = .data[[var]])) +
    geom_raster(interpolate = TRUE, na.rm = FALSE) +
    coord_equal() +
    scale_fill_viridis_c(
      limits = zlim,
      na.value = "red"
    ) +
    theme_minimal(base_size = 11) +
    theme(
      panel.grid = element_blank(),
      plot.title = element_text(face = "bold")
    ) +
    labs(title = title, x = NULL, y = NULL, fill = NULL)
}

## ---- network panel (UNCHANGED) 
## NOTE(review): apart from the node points (size 0.5 / alpha 0.8 here vs
## 0.55 / 0.9 in the first version) this rebuilds the same network panel;
## get.edgelist() is deprecated in recent igraph in favour of as_edgelist().
ed_dir <- as.data.frame(get.edgelist(g_dir))
colnames(ed_dir) <- c("a", "b")
ed_dir$type <- E(g_dir)$type

ed_dir$ax <- V(g_dir)$x[match(ed_dir$a, V(g_dir)$name)]
ed_dir$ay <- V(g_dir)$y[match(ed_dir$a, V(g_dir)$name)]
ed_dir$bx <- V(g_dir)$x[match(ed_dir$b, V(g_dir)$name)]
ed_dir$by <- V(g_dir)$y[match(ed_dir$b, V(g_dir)$name)]

## Tree edges; xy is indexed by vertex name (assumes rownames(xy) are the
## vertex names -- TODO confirm).
ed_T <- as.data.frame(get.edgelist(T0))
colnames(ed_T) <- c("a", "b")
ed_T$ax <- xy[ed_T$a, 1]; ed_T$ay <- xy[ed_T$a, 2]
ed_T$bx <- xy[ed_T$b, 1]; ed_T$by <- xy[ed_T$b, 2]

p_network <- ggplot() +
  geom_segment(
    data = subset(ed_dir, type == "road"),
    aes(x = ax, y = ay, xend = bx, yend = by),
    linewidth = 0.25, alpha = 0.25
  ) +
  geom_segment(
    data = subset(ed_dir, type == "flyover"),
    aes(x = ax, y = ay, xend = bx, yend = by),
    linewidth = 0.8, alpha = 0.35
  ) +
  geom_segment(
    data = ed_T,
    aes(x = ax, y = ay, xend = bx, yend = by),
    linewidth = 0.35, alpha = 0.55
  ) +
  geom_point(
    data = df,
    aes(x = x, y = y),
    size = 0.5, alpha = 0.8
  ) +
  coord_equal() +
  theme_minimal(base_size = 11) +
  labs(
    title = "One-way road network + flyovers (gray)",
    x = NULL, y = NULL
  )

## ---- smooth panels 
g1 <- p_field_smooth("truth", "True traffic intensity")
g2 <- p_field_smooth("sBAST", "PR-BAST")
g3 <- p_field_smooth("BAST",  "BAST")
g4 <- p_field_smooth("GP",    "GP")
g5 <- p_field_smooth("BART",  "BART")

## 2 x 3 layout: network first, then the five smoothed fields.
gridExtra::grid.arrange(
  p_network,
  g1, g2, g3, g4, g5,
  ncol = 3
)



# NYC Taxi ----------------------------------------------------------------
## NYC TLC (REAL DATA) — FULL PIPELINE (Arrow-safe)
## Fixes: POSIXct + hour extraction AFTER collect()

suppressPackageStartupMessages({
  library(arrow)
  library(sf)
  library(igraph)
  library(Matrix)
  library(kernlab)
  library(dbarts)
  library(ggplot2)
  library(gridExtra)
})

set.seed(123)


## 1) Load NYC Taxi Zones (shapefile)
## The zip and its extraction are cached under tempdir(), i.e. only for the
## current R session.

tz_zip <- file.path(tempdir(), "taxi_zones.zip")
tz_dir <- file.path(tempdir(), "taxi_zones")

if (!dir.exists(tz_dir)) dir.create(tz_dir, recursive = TRUE)

tz_url <- "https://d37ci6vzurychx.cloudfront.net/misc/taxi_zones.zip"
if (!file.exists(tz_zip)) {
  download.file(tz_url, tz_zip, mode="wb", quiet=TRUE)
}
unzip(tz_zip, exdir = tz_dir)

## Read the first .shp found in the extracted archive.
shp <- list.files(tz_dir, pattern="\\.shp$", full.names = TRUE)
zones_sf <- st_read(shp[1], quiet = TRUE)

## Sort zones by LocationID. From here on, row i of zones_sf is "zone i"
## for every per-zone vector (centroids, Y rows, graph vertices).
## NOTE(review): later code matches on LocationID with match(), which uses
## the first occurrence -- verify LocationID is unique in this shapefile.
zones_sf$LocationID <- as.integer(zones_sf$LocationID)
zones_sf <- zones_sf[order(zones_sf$LocationID), ]

## Projected CRS for centroids
## EPSG:3857 (Web Mercator) gives planar coordinates for st_centroid().
zones_m <- st_transform(zones_sf, 3857)
cent_m  <- st_centroid(zones_m)
cent_xy <- st_coordinates(cent_m)

cx <- cent_xy[,1]
cy <- cent_xy[,2]

## p = number of zones; v_order = zone labels in zones_sf row order.
p <- nrow(zones_sf)
v_order <- as.character(zones_sf$LocationID)


## 2) Adjacency graph from polygon touches
## Vertex k (named "k") is row k of zones_sf. st_touches() is symmetric, so
## each touching pair contributes two edges (i->j and j->i). The resulting
## parallel edges are harmless because a spanning tree can keep at most one.
## NOTE(review): a zone that touches no other polygon becomes an isolated
## vertex, in which case T0 below is a spanning forest, not a single tree.
touch_list <- st_touches(zones_m)

## Edge endpoints in one vectorised step: row index i repeated once per
## touching neighbour, paired with the neighbour indices in the same order.
ef <- rep(seq_len(p), lengths(unclass(touch_list)))
et <- as.integer(unlist(unclass(touch_list), use.names = FALSE))

g_adj <- graph_from_data_frame(
  data.frame(from=ef, to=et),
  directed = FALSE,
  vertices = data.frame(name=as.character(seq_len(p)))
)

V(g_adj)$LocationID <- zones_sf$LocationID
V(g_adj)$cx <- cx
V(g_adj)$cy <- cy


## 3) Spanning tree via MST (centroid distances)
## Edge weight = Euclidean distance between the two zone centroids.
el <- ends(g_adj, E(g_adj))
## Vertex names are "1".."p", so as.integer() recovers the zone row index.
ai <- as.integer(el[,1])
bi <- as.integer(el[,2])
E(g_adj)$w <- sqrt((cx[ai]-cx[bi])^2 + (cy[ai]-cy[bi])^2)

T0 <- mst(g_adj, weights = E(g_adj)$w)
T0 <- simplify(T0)


## 4) Load TLC yellow taxi parquet (Jan 2019)

trip_url <- "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2019-01.parquet"
trip_parquet <- file.path(tempdir(), "yellow_2019_01.parquet")

## The monthly file is large. Raise download.file()'s timeout for this call
## only, and delete any partial file on failure so the file.exists() cache
## check cannot pick up a truncated parquet on the next run.
if (!file.exists(trip_parquet)) {
  old_opts <- options(timeout = max(600, getOption("timeout")))
  dl_ok <- tryCatch(
    download.file(trip_url, trip_parquet, mode="wb", quiet=TRUE) == 0L,
    error = function(e) {
      message("download failed: ", conditionMessage(e))
      FALSE
    }
  )
  options(old_opts)
  if (!dl_ok) {
    unlink(trip_parquet)
    stop("Could not download ", trip_url, call. = FALSE)
  }
}

ds <- open_dataset(trip_parquet, format="parquet")

pickup_col <- "tpep_pickup_datetime"
pu_col     <- "PULocationID"


## 5) collect() BEFORE POSIXct / hour extraction
## Only column selection and the NA filter run inside the Arrow query; all
## datetime handling happens in plain R after collect().

tab_raw <- ds |>
  dplyr::select(all_of(c(pu_col, pickup_col))) |>
  dplyr::filter(!is.na(.data[[pu_col]])) |>
  dplyr::collect()

## NOTE(review): if tpep_pickup_datetime already arrives as POSIXct,
## as.POSIXct(..., tz = "UTC") may keep its existing tzone attribute, and
## then format() and as.Date() could disagree on the zone. Confirm the
## intended "8AM" clock time.
tab <- within(tab_raw, {
  t    <- as.POSIXct(tpep_pickup_datetime, tz="UTC")
  day  <- as.Date(t)
  hour <- as.integer(format(t, "%H"))
})

## Keep 8AM pickups whose date lies inside the file's nominal month. The raw
## TLC monthly files contain a few records with out-of-range pickup dates.
## Left in, each such date would become an almost all-zero extra column of Y
## and land in the train or test split purely by sort order.
month_start <- as.Date("2019-01-01")
month_end   <- as.Date("2019-01-31")

tab <- tab |>
  dplyr::filter(hour == 8L, day >= month_start, day <= month_end) |>
  dplyr::group_by(day, PULocationID) |>
  dplyr::summarise(cnt = dplyr::n(), .groups="drop")

tab$PULocationID <- as.integer(tab$PULocationID)
tab <- tab[tab$PULocationID %in% zones_sf$LocationID, ]

## 6) Build Y matrix (zones × days) → n >> p
## Y[i, d] = number of 8AM pickups in zone i (zones_sf row order) on day d;
## zone/day combinations with no pickups stay 0.

days <- sort(unique(tab$day))
nd   <- length(days)

Y <- matrix(
  0,
  nrow = p,
  ncol = nd,
  dimnames = list(as.character(zones_sf$LocationID), as.character(days))
)

iz <- match(tab$PULocationID, zones_sf$LocationID)
id <- match(as.character(tab$day), as.character(days))
Y[cbind(iz, id)] <- tab$cnt

## Train/test split over days: the first 70% of (sorted) days train, the
## rest test. At least 2 training days are needed for the per-zone var()
## below, and at least 1 test day.
nd_train <- floor(0.7 * nd)
stopifnot(nd_train >= 2, nd_train < nd)
train_days <- seq_len(nd_train)
test_days  <- (nd_train+1):nd

y_train <- rowMeans(Y[, train_days, drop=FALSE])
y_test  <- rowMeans(Y[, test_days,  drop=FALSE])

## Log-scale traffic intensity
## f0 (truth) = held-out test-day average; ybar (data) = training average.
f0   <- log1p(y_test)
ybar <- log1p(y_train)

## Noise scale: RMS over zones of the day-to-day sd of log1p counts.
sigma <- sqrt(mean(apply(log1p(Y[,train_days,drop=FALSE]),1,var)))

## 7) Soft-BAST (single component, posterior mean)

## Row-normalised Gaussian smoothing weights over tree neighbourhoods.
##   T       : igraph tree/forest whose vertices carry a `LocationID` attribute
##   tau     : decay rate in exp(-tau * d^2), d = igraph::distances() in T
##   radius  : neighbourhood order passed to ego()
##   ref_ids : LocationIDs fixing the row/column order of the result
##             (default zones_sf$LocationID, as previously hard-wired)
## Returns a length(ref_ids) x length(ref_ids) sparse matrix. The row of
## zone i holds weights summing to 1 on the zones within `radius` hops.
##
## Vertices are addressed by numeric id, not by as.character(id), which only
## worked because T0's vertex names happen to be "1".."p". Triplets are built
## per vertex and concatenated once instead of growing three vectors
## inside the loop.
softW <- function(T, tau=0.8, radius=6L, ref_ids = zones_sf$LocationID) {
  n_ref   <- length(ref_ids)
  loc_ids <- V(T)$LocationID

  triplets <- lapply(seq_len(vcount(T)), function(v) {
    neigh <- as.integer(ego(T, order=radius, nodes=v)[[1]])
    d <- as.numeric(distances(T, v=v, to=neigh))
    w <- exp(-tau*d^2)
    j <- match(loc_ids[neigh], ref_ids)
    list(i = rep(match(loc_ids[v], ref_ids), length(j)), j = j, x = w/sum(w))
  })

  sparseMatrix(
    i = as.integer(unlist(lapply(triplets, `[[`, "i"))),
    j = as.integer(unlist(lapply(triplets, `[[`, "j"))),
    x = as.numeric(unlist(lapply(triplets, `[[`, "x"))),
    dims = c(n_ref, n_ref)
  )
}

W  <- softW(T0, tau=0.8)
L  <- laplacian_matrix(T0, sparse=TRUE)
## Reorder the Laplacian from T0 vertex order to zone (v_order) order so it
## lines up with W's rows/columns.
ord <- match(v_order, as.character(V(T0)$LocationID))
L  <- L[ord,ord]

## Solve A beta = b with A = W'W/sigma^2 + 10 L + 0.01 I and
## b = W' ybar / sigma^2. This is the Gaussian posterior mean of beta under
## ybar ~ N(W beta, sigma^2 I) with prior precision 10 L + 0.01 I.
## The fitted surface is W beta_hat.
A  <- crossprod(W)/sigma^2 + 10*L + 1e-2*Diagonal(p)
b  <- crossprod(W,ybar)/sigma^2

beta_hat <- as.numeric(solve(A,b))
sbast_hat <- as.numeric(W %*% beta_hat)


## 8) BAST (hard tree cuts)
## Greedy BAST: cut the 30 tree edges with the largest |ybar| jump between
## their endpoints, then predict each zone by the mean of ybar over its
## resulting component.
##
## ends(..., names = FALSE) returns numeric vertex ids. With the default
## names = TRUE it returns vertex *names* (character). Indexing the unnamed
## integer vector V(T0)$LocationID by a character vector yields NA, so every
## jump would be NA and order() would just pick the first 30 edge ids.
el <- ends(T0, E(T0), names = FALSE)
loc_ids <- V(T0)$LocationID
ya <- ybar[match(loc_ids[el[,1]], zones_sf$LocationID)]
yb <- ybar[match(loc_ids[el[,2]], zones_sf$LocationID)]

n_cut <- min(30L, ecount(T0))  # never request more cuts than edges
cut <- order(abs(ya-yb), decreasing=TRUE)[seq_len(n_cut)]
Tc  <- delete_edges(T0, E(T0)[cut])
comp <- components(Tc)$membership
mu <- tapply(ybar[match(loc_ids, zones_sf$LocationID)], comp, mean)
## Component means broadcast back to vertices (T0 vertex order, which is
## zones_sf row order because g_adj was built with vertices 1..p).
bast_hat <- as.numeric(mu[as.character(comp)])

## 9) GP + BART (Euclidean)

## Standardise each centroid coordinate column separately. Dividing the whole
## matrix by one pooled sd() would mix the very different x and y offsets of
## the projected coordinates into a single scale factor. (gausspr() rescales
## inputs by default, and tree splits are scale-invariant, so this mainly
## makes the intent explicit.)
xy <- scale(cbind(cx, cy))

gp_fit <- gausspr(x = xy, y = ybar, kernel = "rbfdot", kpar = list(sigma = 0.35))
## predict() may return a one-column matrix; keep a plain numeric vector
## so later data-frame/plot code receives an ordinary column.
gp_hat <- as.numeric(predict(gp_fit, xy))

bart_fit <- dbarts::bart(
  x.train = xy, y.train = ybar,
  ntree = 30, k = 5, ndpost = 150, nskip = 150, verbose = FALSE
)
## Posterior mean of the in-sample fit (average over the ndpost draws).
bart_hat <- colMeans(bart_fit$yhat.train)


## 10) Metrics

## Root-mean-squared error between two equal-length numeric vectors.
rmse  <- function(a,b) sqrt(mean((a-b)^2))
## Symmetric MAPE. `eps` (default 1e-6, the value previously hard-coded)
## keeps the denominator positive when a and b are both zero; exposed as an
## argument to match the other smape() definitions in this script.
smape <- function(a, b, eps = 1e-6) mean(2 * abs(a - b) / (abs(a) + abs(b) + eps))

## Score each NYC estimate against the held-out truth f0 and print.
results <- local({
  preds <- list(sbast_hat, bast_hat, gp_hat, bart_hat)
  score <- function(metric) vapply(preds, function(est) metric(f0, est), numeric(1))
  data.frame(
    method = c("sBAST","BAST","GP","BART"),
    RMSE   = score(rmse),
    sMAPE  = score(smape)
  )
})

print(results)

## NYC MAP PLOTS — Truth vs sBAST / BAST / GP / BART
suppressPackageStartupMessages({
  library(sf)
  library(ggplot2)
  library(gridExtra)
  library(dplyr)
})

## 1) Attach fitted values to zones_sf
## Every estimate is coerced with as.numeric(). predict() on the GP model may
## return a one-column matrix, and the BAST estimate can carry array/name
## attributes; plain numeric columns keep mutate() and geom_sf() simple.
## All vectors are assumed to be in zones_sf row order (zone i = row i).
plot_sf <- zones_sf %>%
  mutate(
    truth = as.numeric(f0),
    sBAST = as.numeric(sbast_hat),
    BAST  = as.numeric(bast_hat),
    GP    = as.numeric(gp_hat),
    BART  = as.numeric(bart_hat)
  )

## Shared color limits (CRITICAL)
## One range over all five surfaces so the panels are directly comparable.
zlim <- range(
  plot_sf$truth,
  plot_sf$sBAST,
  plot_sf$BAST,
  plot_sf$GP,
  plot_sf$BART,
  na.rm = TRUE
)

## 2) Reusable plotting function

## Choropleth of column `var` of the global plot_sf on the shared zlim.
map_panel <- function(var, title) {
  ggplot(plot_sf) +
    geom_sf(aes(fill = .data[[var]]),
            color = "grey40",
            linewidth = 0.15) +
    scale_fill_viridis_c(
      limits = zlim,
      option = "C",
      name = NULL
    ) +
    coord_sf(datum = NA) +
    labs(title = title) +
    theme_minimal(base_size = 12) +
    theme(
      plot.title = element_text(face = "bold", size = 13),
      panel.grid = element_blank(),
      axis.text  = element_blank(),
      axis.title = element_blank(),
      legend.position = "right"
    )
}

## 3) Individual panels

p_truth <- map_panel("truth", "Truth: Avg log(1 + pickups), 8AM")
p_sbast <- map_panel("sBAST", "PR-BAST")
p_bast  <- map_panel("BAST",  "BAST")
p_gp    <- map_panel("GP",    "GP")
p_bart  <- map_panel("BART",  "BART")

## 4) Arrange (publication layout)

gridExtra::grid.arrange(
  p_truth, p_sbast, p_bast,
  p_gp,    p_bart,
  ncol = 3
)



# Repeated 1, varying grid size -------------------------------------------


suppressPackageStartupMessages({
  library(igraph)
  library(Matrix)
  library(dbarts)
})

set.seed(2028)


## Root-mean-squared error between two equal-length numeric vectors.
rmse <- function(a, b) {
  resid <- a - b
  sqrt(mean(resid^2))
}

## Pairwise squared Euclidean distances between the rows of A (n x d) and
## B (m x d), via ||a||^2 + ||b||^2 - 2 a.b; returns an n x m matrix.
sqdist <- function(A, B) {
  sq_a <- rowSums(A^2)
  sq_b <- rowSums(B^2)
  outer(sq_a, sq_b, "+") - 2 * A %*% t(B)
}

## One replicate of the "grid with hole + bridge" benchmark.
##
## Steps:
##   * Build an nr x nc lattice on [0,1]^2.
##   * Remove the disk of radius 0.12 around (0.5, 0.55), but keep nodes in
##     the band |x - 0.5| < 0.03, 0.38 < y < 0.72 as a "bridge".
##   * Draw a random spanning tree (MST on i.i.d. uniform edge weights).
##   * Simulate a tree-aligned truth: 6 constant regions plus a smooth signal
##     along a DFS order.
##   * Observe it 4 times with Gaussian noise and score four estimators.
##   nr, nc : grid rows / columns (both > 1)
##   seed   : RNG seed, set on entry
## Returns list(sBAST, BAST, GP, BART), each list(rmse = <num>, time = <sec>).
run_one <- function(nr, nc, seed) {
  set.seed(seed)

  ## --- Grid with hole + bridge (non-Euclidean geometry)
  ## NOTE: x takes values (c - 1) / (nc - 1), so for some nc (e.g. 12, 14, 16)
  ## no lattice column falls inside the bridge band.
  grid <- expand.grid(r = 1:nr, c = 1:nc)
  grid$x <- (grid$c - 1) / (nc - 1)
  grid$y <- (grid$r - 1) / (nr - 1)

  hole   <- with(grid, (x - 0.5)^2 + (y - 0.55)^2 < 0.12^2)
  bridge <- with(grid, abs(x - 0.5) < 0.03 & y > 0.38 & y < 0.72)
  keep   <- (!hole) | bridge

  nodes <- grid[keep, ]
  nodes$id <- seq_len(nrow(nodes))
  nodes$name <- as.character(nodes$id)
  p <- nrow(nodes)

  key <- paste(nodes$r, nodes$c, sep = "_")
  map <- setNames(nodes$name, key)

  ## 4-neighbour candidates (up, down, left, right) for every node in one
  ## vectorised pass: node-major, direction-minor, i.e. the same order a
  ## per-node loop produces. Integer offsets keep the keys in the same
  ## "r_c" text format as `key`.
  d_r <- c(-1L, 1L, 0L, 0L)
  d_c <- c(0L, 0L, -1L, 1L)
  cand_key <- paste(rep(nodes$r, each = 4L) + d_r,
                    rep(nodes$c, each = 4L) + d_c, sep = "_")
  nbr <- unname(map[cand_key])
  has_nbr <- !is.na(nbr)
  edges_df <- unique(data.frame(
    from = rep(nodes$name, each = 4L)[has_nbr],
    to   = nbr[has_nbr]
  ))

  ## `vertices` pins the vertex order to nodes$name and keeps all p nodes
  ## even if one had no neighbour; the code below indexes vertices by 1..p.
  g <- graph_from_data_frame(
    edges_df,
    directed = FALSE,
    vertices = data.frame(name = nodes$name)
  )
  V(g)$x <- nodes$x[match(V(g)$name, nodes$name)]
  V(g)$y <- nodes$y[match(V(g)$name, nodes$name)]
  xy <- cbind(V(g)$x, V(g)$y)

  ## --- Spanning tree
  E(g)$w <- runif(ecount(g))
  Tg <- mst(g, weights = E(g)$w)

  ## Squared tree distances (p x p). Tg has no `weight` edge attribute, so
  ## distances() counts hops.
  D2 <- distances(Tg)^2

  ## --- Truth: region + smooth tree signal
  ## Cutting K-1 random tree edges splits the tree into K regions.
  K <- 6
  E(Tg)$score <- runif(ecount(Tg))
  cut <- order(E(Tg)$score, decreasing = TRUE)[1:(K-1)]
  comp <- components(delete_edges(Tg, E(Tg)[cut]))$membership

  mu_region <- rnorm(K, sd = 1.2)
  f_region <- mu_region[comp]

  ## Smooth component: sinusoids of each vertex's DFS visit rank in (0, 1].
  root <- sample(seq_len(p), 1)
  ord  <- as.integer(dfs(Tg, root = root, order = TRUE)$order)
  rank <- numeric(p); rank[ord] <- seq_len(p) / p

  f_smooth <- 0.25*sin(6*pi*rank) + 0.15*sin(11*pi*rank)
  f0 <- f_region + f_smooth

  ## --- Observations
  ## Y is p x n_rep: row i holds n_rep noisy copies of f0[i].
  n_rep <- 4
  sigma <- 0.35
  Y <- matrix(rep(f0, each = n_rep), nrow = p, byrow = TRUE) +
    matrix(rnorm(p*n_rep, sd = sigma), nrow = p)
  ybar <- rowMeans(Y)
  sig2 <- sigma^2 / n_rep   # noise variance of ybar

  out <- list()

  ## =======================================================
  ## sBAST
  ## Row-normalised exp(-d^2) tree weights, N(0, I) prior on beta, Gaussian
  ## posterior mean; the true noise variance sig2 is used.
  t0 <- proc.time()[3]
  W <- exp(-D2); W <- W / rowSums(W)
  sbast_hat <- as.numeric(
    W %*% solve(crossprod(W)/sig2 + Diagonal(p),
                crossprod(W, ybar)/sig2)
  )
  out$sBAST <- list(
    rmse = rmse(sbast_hat, f0),
    time = proc.time()[3] - t0
  )

  ## =======================================================
  ## BAST
  ## NOTE(review): this uses the TRUE partition `comp` from the simulation,
  ## i.e. an oracle BAST that only estimates region means. Its RMSE is
  ## therefore optimistic relative to a BAST that must learn the cuts.
  t0 <- proc.time()[3]
  mu_hat <- tapply(ybar, comp, mean)
  bast_hat <- as.numeric(mu_hat[as.character(comp)])
  out$BAST <- list(
    rmse = rmse(bast_hat, f0),
    time = proc.time()[3] - t0
  )

  ## =======================================================
  ## GP
  ## Exact GP posterior mean with a fixed squared-exponential lengthscale 0.22
  ## on the Euclidean coordinates and the true noise variance.
  t0 <- proc.time()[3]
  Kmat <- exp(-sqdist(xy, xy)/(2*0.22^2))
  gp_hat <- as.numeric(Kmat %*% solve(Kmat + diag(sig2, p), ybar))
  out$GP <- list(
    rmse = rmse(gp_hat, f0),
    time = proc.time()[3] - t0
  )

  ## =======================================================
  ## BART
  ## Trained on all replicate observations (each node's coordinates repeated
  ## n_rep times); predictions are taken at the p unique nodes.
  t0 <- proc.time()[3]
  X_long <- xy[rep(seq_len(p), each = n_rep), ]
  y_long <- as.vector(t(Y))
  bart_fit <- bart(
    X_long, y_long, xy,
    ntree = 60, nskip = 150, ndpost = 150, verbose = FALSE
  )
  bart_hat <- bart_fit$yhat.test.mean
  out$BART <- list(
    rmse = rmse(bart_hat, f0),
    time = proc.time()[3] - t0
  )

  out
}


grid_sizes <- c(12, 14, 16, 18, 20, 22, 24, 26)
R <- 10   # repetitions per grid

## For each square grid size, run R seeded replicates of run_one() and keep
## the mean / variance of RMSE and the mean runtime per method.
## The loop variable is `n_side` (not `g`) so the loop no longer overwrites
## the igraph object `g` built earlier in the script.
results <- list()

for (n_side in grid_sizes) {
  cat("Running grid", n_side, "x", n_side, "\n")
  reps <- lapply(seq_len(R), function(r)
    run_one(n_side, n_side, seed = 1000*n_side + r)
  )

  for (method in c("sBAST","BAST","BART","GP")) {
    ## vapply() guarantees exactly one numeric value per replicate (sapply()
    ## would silently return a list for a malformed replicate).
    rmse_vec <- vapply(reps, function(z) z[[method]]$rmse, numeric(1))
    time_vec <- vapply(reps, function(z) z[[method]]$time, numeric(1))

    results[[length(results)+1]] <- data.frame(
      grid = paste0(n_side,"x",n_side),
      method = method,
      mean_RMSE = mean(rmse_vec),
      var_RMSE  = var(rmse_vec),
      mean_time = mean(time_vec)
    )
  }
}

results_df <- do.call(rbind, results)
print(results_df)


## Wide table: one row per grid, one "mean (variance)" RMSE column per method.
rmse_table <- reshape(
  transform(
    results_df,
    RMSE = sprintf("%.3f (%.3f)", mean_RMSE, var_RMSE)
  )[, c("grid","method","RMSE")],
  timevar = "method",
  idvar   = "grid",
  direction = "wide"
)

print(rmse_table)




# Repeated 2, varying grid size -------------------------------------------


suppressPackageStartupMessages({
  library(igraph)
  library(Matrix)
  library(ggplot2)
  library(gridExtra)
})

## Optional deps (auto-install if missing)
## Only packages that cannot already be loaded are installed from CRAN.
pkgs_needed <- c("dbarts", "kernlab")
for (pname in pkgs_needed) {
  if (requireNamespace(pname, quietly = TRUE)) next
  install.packages(pname, repos = "https://cloud.r-project.org")
}
suppressPackageStartupMessages({
  library(dbarts)   # BART
  library(kernlab)  # GP via gausspr
})

set.seed(123)


## Root-mean-squared error between two equal-length numeric vectors.
rmse <- function(a, b) {
  diff_ab <- a - b
  sqrt(mean(diff_ab^2))
}
## Symmetric MAPE; `eps` guards the denominator when a and b are both zero.
smape <- function(a, b, eps = 1e-6) {
  scale_ab <- abs(a) + abs(b) + eps
  mean(2 * abs(a - b) / scale_ab)
}


## Row-normalised Gaussian smoothing weights over tree neighbourhoods.
##   T      : named igraph graph (the spanning tree)
##   tau    : decay rate in exp(-tau * d^2), d = igraph::distances() in T
##   radius : neighbourhood order passed to ego()
## Returns a pT x pT sparse matrix in V(T) order. Row k holds weights that
## sum to 1 on the vertices within `radius` hops of vertex k.
## Triplets are collected per vertex and concatenated once; growing three
## vectors with c() inside the loop made construction quadratic in the
## number of non-zeros.
soft_weights_local_sparse <- function(T, tau, radius = 6L) {
  Vnames <- V(T)$name
  pT <- length(Vnames)
  name_to_idx <- setNames(seq_len(pT), Vnames)

  triplets <- vector("list", pT)
  for (k in seq_len(pT)) {
    vn <- Vnames[k]
    neigh_vid <- ego(T, order = radius, nodes = vn, mode = "all")[[1]]
    neigh_names <- V(T)[neigh_vid]$name

    dvec <- as.numeric(distances(T, v = vn, to = neigh_names, mode = "all"))
    w <- exp(-tau * (dvec^2))

    triplets[[k]] <- list(
      i = rep(name_to_idx[[vn]], length(neigh_names)),
      j = unname(name_to_idx[neigh_names]),
      x = w / sum(w)
    )
  }

  sparseMatrix(
    i = as.integer(unlist(lapply(triplets, `[[`, "i"))),
    j = as.integer(unlist(lapply(triplets, `[[`, "j"))),
    x = as.numeric(unlist(lapply(triplets, `[[`, "x"))),
    dims = c(pT, pT)
  )
}


## One replicate of the one-way-road + flyover benchmark.
##
## Steps:
##   * Build an nr x nc directed road grid (one-way horizontal streets that
##     converge on the middle column, southward-only vertical streets).
##   * Add up to K_fly two-way "flyovers", each linking a random node to its
##     Euclidean-farthest unused node.
##   * Take an MST of the undirected graph that strongly prefers flyovers.
##   * Simulate a tree-aligned truth and noisy observations.
##   * Score sBAST / BAST / GP / BART against the truth.
##
## Arguments:
##   lam0, eta0, tau0, radius0 : truth generator (Laplacian GMRF + smoothing)
##   sigma                     : observation noise sd (also given to sBAST)
##   sb_*                      : sBAST hyper-parameters
##   bast_k_cuts               : number of edges BAST cuts
##   gp_sigma                  : kernlab rbfdot width for the GP
##   bart_*                    : dbarts settings (deliberately restrictive)
##   do_plot                   : also build a gtable of diagnostic panels
## Returns list(results = data.frame(method, RMSE, sMAPE, time_sec),
##              plot = gtable or NULL).
run_one_traffic <- function(nr, nc, seed,
                            K_fly = 34,
                            lam0 = 30.0, eta0 = 1e-2, tau0 = 0.55, radius0 = 7L,
                            sigma = 0.18,
                            sb_tau = 0.6, sb_radius = 7L, sb_lam = 12.0, sb_eta = 1e-2,
                            bast_k_cuts = 14L,
                            gp_sigma = 0.35,
                            bart_ntree = 20, bart_k = 5.0, bart_power = 2.0, bart_base = 0.90,
                            do_plot = FALSE) {
  
  set.seed(seed)
  
  ## 1) Build ROAD GRID with one-way traffic
  ## Node i = row i of `grid`, named as.character(i).
  grid <- expand.grid(r = 1:nr, c = 1:nc)
  grid$id   <- seq_len(nrow(grid))
  grid$name <- as.character(grid$id)
  grid$x <- (grid$c - 1) / (nc - 1)
  grid$y <- (grid$r - 1) / (nr - 1)
  
  p <- nrow(grid)
  
  key <- paste(grid$r, grid$c, sep = "_")
  name_map <- setNames(grid$name, key)
  
  ef <- character(0); et <- character(0)
  edge_type <- character(0)
  
  for (i in seq_len(p)) {
    r0 <- grid$r[i]; c0 <- grid$c[i]
    left_half <- (c0 <= floor(nc/2))
    
    ## one-way horizontal
    ## Left half flows right (c -> c+1), right half flows left (c -> c-1).
    if (left_half) {
      if (c0 < nc) {
        jn <- name_map[paste(r0, c0 + 1, sep = "_")]
        ef <- c(ef, grid$name[i]); et <- c(et, jn)
        edge_type <- c(edge_type, "road")
      }
    } else {
      if (c0 > 1) {
        jn <- name_map[paste(r0, c0 - 1, sep = "_")]
        ef <- c(ef, grid$name[i]); et <- c(et, jn)
        edge_type <- c(edge_type, "road")
      }
    }
    
    ## one-way vertical: SOUTH only
    ## (r -> r + 1)
    if (r0 < nr) {
      jn <- name_map[paste(r0 + 1, c0, sep = "_")]
      ef <- c(ef, grid$name[i]); et <- c(et, jn)
      edge_type <- c(edge_type, "road")
    }
  }
  
  Eroad <- unique(data.frame(from = ef, to = et, type = edge_type, stringsAsFactors = FALSE))
  
  ## 2) Add flyovers (far in Euclidean, short in graph)
  ## Greedy pairing without reuse: pick a random unused node and pair it with
  ## the farthest remaining unused node.
  K_fly <- min(K_fly, floor(p/2))
  xy_all <- as.matrix(grid[, c("x", "y")])
  D <- as.matrix(dist(xy_all))
  diag(D) <- -Inf
  
  pairs <- matrix(NA_integer_, nrow = 0, ncol = 2)
  used <- rep(FALSE, p)
  
  for (k in seq_len(K_fly)) {
    cand_i <- which(!used)
    if (length(cand_i) < 2) break
    i <- sample(cand_i, 1)
    cand_j <- setdiff(cand_i, i)
    j <- cand_j[which.max(D[i, cand_j])]
    pairs <- rbind(pairs, c(i, j))
    used[c(i, j)] <- TRUE
  }
  if (nrow(pairs) == 0) stop("No flyover pairs created; reduce K_fly or increase grid.")
  
  ## Flyovers are two-way: one directed edge in each direction.
  Efly <- data.frame(
    from = c(grid$name[pairs[,1]], grid$name[pairs[,2]]),
    to   = c(grid$name[pairs[,2]], grid$name[pairs[,1]]),
    type = "flyover",
    stringsAsFactors = FALSE
  )
  
  Edf_dir <- unique(rbind(Eroad, Efly))
  
  g_dir <- graph_from_data_frame(
    d = Edf_dir,
    directed = TRUE,
    vertices = data.frame(name = grid$name, stringsAsFactors = FALSE)
  )
  V(g_dir)$x <- grid$x[match(V(g_dir)$name, grid$name)]
  V(g_dir)$y <- grid$y[match(V(g_dir)$name, grid$name)]
  
  ## 3) Undirected graph + MST that uses flyovers (cheap)
  ## Flyovers appear in both directions here, so g_und has parallel flyover
  ## edges; mst() can keep at most one of each pair.
  Edf_und <- unique(data.frame(
    a = Edf_dir$from,
    b = Edf_dir$to,
    type = Edf_dir$type,
    stringsAsFactors = FALSE
  ))
  
  g_und <- graph_from_data_frame(
    d = Edf_und[, c("a", "b")],
    directed = FALSE,
    vertices = data.frame(name = grid$name, stringsAsFactors = FALSE)
  )
  V(g_und)$x <- V(g_dir)$x[match(V(g_und)$name, V(g_dir)$name)]
  V(g_und)$y <- V(g_dir)$y[match(V(g_und)$name, V(g_dir)$name)]
  
  ## Recover each undirected edge's type through an orientation-free key.
  ## pmin/pmax act on the name strings, consistently on both sides.
  ekey <- paste(pmin(Edf_und$a, Edf_und$b), pmax(Edf_und$a, Edf_und$b), sep = "__")
  etype_map <- setNames(Edf_und$type, ekey)
  el_und <- ends(g_und, E(g_und))
  ekey_g <- paste(pmin(el_und[,1], el_und[,2]), pmax(el_und[,1], el_und[,2]), sep = "__")
  E(g_und)$type <- as.character(etype_map[ekey_g])
  E(g_und)$type[is.na(E(g_und)$type)] <- "road"
  
  ## Flyovers cost 0.03 vs 1.0 for roads, so the MST takes them first.
  E(g_und)$w_mst <- ifelse(E(g_und)$type == "flyover", 0.03, 1.0)
  
  T0 <- mst(g_und, weights = E(g_und)$w_mst)
  T0 <- simplify(T0, remove.multiple = TRUE, remove.loops = TRUE)
  
  V(T0)$x <- V(g_und)$x[match(V(T0)$name, V(g_und)$name)]
  V(T0)$y <- V(g_und)$y[match(V(T0)$name, V(g_und)$name)]
  
  ## Coordinates in grid order, with rownames = vertex names (the plotting
  ## code below indexes xy by name).
  v_order <- grid$name
  xy <- cbind(
    x = V(T0)$x[match(v_order, V(T0)$name)],
    y = V(T0)$y[match(v_order, V(T0)$name)]
  )
  rownames(xy) <- v_order
  
  ## 5) Tree-aligned truth via Laplacian GMRF + smoothing W0
  ## NOTE(review): Matrix::solve(<CHMfactor>, b) defaults to system = "A",
  ## i.e. it solves Q0 x = z. That gives beta0 ~ N(0, Q0^{-2}), not the
  ## GMRF draw N(0, Q0^{-1}). The truth is only a smoother field after the
  ## standardisation below, but confirm this is the intended generator.
  L0 <- laplacian_matrix(T0, normalized = FALSE, sparse = TRUE)
  Q0 <- lam0 * L0 + eta0 * Diagonal(n = p)
  cholQ0 <- Cholesky(Q0, LDL = FALSE, super = TRUE)
  beta0 <- as.numeric(solve(cholQ0, rnorm(p)))
  
  ## Smooth along the tree, then rescale to mean 2 and sd 1.5.
  W0 <- soft_weights_local_sparse(T0, tau = tau0, radius = radius0)
  f0 <- as.numeric(W0 %*% beta0)
  f0 <- 2.0 + 1.5 * (f0 - mean(f0)) / sd(f0)
  
  y <- f0 + rnorm(p, 0, sigma)
  
  ## 6) Methods ------------------------------------------------
  
  ## sBAST
  ## Gaussian posterior mean of beta with precision
  ## W'W/sigma^2 + sb_lam L + sb_eta I (true sigma used); fit = W beta_hat.
  t0 <- proc.time()[3]
  W <- soft_weights_local_sparse(T0, tau = sb_tau, radius = sb_radius)
  L <- laplacian_matrix(T0, normalized = FALSE, sparse = TRUE)
  A <- (crossprod(W) / (sigma^2)) + (sb_lam * L) + (sb_eta * Diagonal(n = p))
  b <- (crossprod(W, y) / (sigma^2))
  cholA <- Cholesky(A, LDL = FALSE, super = TRUE)
  beta_hat <- as.numeric(solve(cholA, b))
  sb_hat <- as.numeric(W %*% beta_hat)
  t_sb <- proc.time()[3] - t0
  
  ## BAST (hard cuts)
  ## Cut the k tree edges with the largest |y| jump, then predict by
  ## component means; results are reordered from T0 order to v_order.
  t0 <- proc.time()[3]
  el <- ends(T0, E(T0))
  a <- el[,1]; b2 <- el[,2]
  ya <- y[match(a, v_order)]
  yb <- y[match(b2, v_order)]
  wcut <- abs(ya - yb)
  ord <- order(wcut, decreasing = TRUE)
  k_cuts <- min(bast_k_cuts, ecount(T0))
  Tc <- delete_edges(T0, E(T0)[ord[seq_len(k_cuts)]])
  comp <- components(Tc)$membership
  mu <- tapply(y[match(V(T0)$name, v_order)], comp, mean)
  ba_hat <- as.numeric(mu[as.character(comp)])
  ba_hat <- ba_hat[match(v_order, V(T0)$name)]
  t_ba <- proc.time()[3] - t0
  
  ## GP (Euclidean)
  t0 <- proc.time()[3]
  gp <- gausspr(x = xy, y = y, kernel = "rbfdot", kpar = list(sigma = gp_sigma))
  gp_hat <- as.numeric(predict(gp, xy))
  t_gp <- proc.time()[3] - t0
  
  ## BART (regularized to be worse)
  t0 <- proc.time()[3]
  bf <- dbarts::bart(
    x.train = xy, y.train = y,
    ntree = bart_ntree,
    k = bart_k,
    power = bart_power,
    base = bart_base,
    nskip = 150, ndpost = 150,
    keeptrees = TRUE,
    verbose = FALSE
  )
  bt_hat <- as.numeric(colMeans(bf$yhat.train))
  t_bt <- proc.time()[3] - t0
  
  ## Metrics summary row
  ## One row per method, scored against the noise-free truth f0.
  out <- data.frame(
    method   = c("sBAST","BAST","GP","BART"),
    RMSE     = c(rmse(f0, sb_hat), rmse(f0, ba_hat), rmse(f0, gp_hat), rmse(f0, bt_hat)),
    sMAPE    = c(smape(f0, sb_hat), smape(f0, ba_hat), smape(f0, gp_hat), smape(f0, bt_hat)),
    time_sec = c(t_sb, t_ba, t_gp, t_bt),
    stringsAsFactors = FALSE
  )
  
  ## Optional plot for this run
  ## Built with arrangeGrob(), so nothing is drawn here; the caller draws it.
  plot_obj <- NULL
  if (isTRUE(do_plot)) {
    df <- data.frame(
      id    = v_order,
      x     = xy[,1],
      y     = xy[,2],
      truth = f0,
      sBAST = sb_hat,
      BAST  = ba_hat,
      GP    = gp_hat,
      BART  = bt_hat
    )
    zlim <- range(unlist(df[, c("truth","sBAST","BAST","GP","BART")]))
    
    p_field <- function(var, title) {
      ggplot(df, aes(x = x, y = y, color = .data[[var]])) +
        geom_point(size = 1.35) +
        coord_equal() +
        scale_color_viridis_c(limits = zlim) +
        labs(title = title, x = NULL, y = NULL, color = NULL) +
        theme_minimal(base_size = 11) +
        theme(panel.grid = element_blank(),
              plot.title = element_text(face="bold", size=12))
    }
    
    ## NOTE(review): get.edgelist() is deprecated in recent igraph releases
    ## in favour of as_edgelist().
    ed_dir <- as.data.frame(get.edgelist(g_dir))
    colnames(ed_dir) <- c("a", "b")
    ed_dir$type <- E(g_dir)$type
    ed_dir$ax <- V(g_dir)$x[match(ed_dir$a, V(g_dir)$name)]
    ed_dir$ay <- V(g_dir)$y[match(ed_dir$a, V(g_dir)$name)]
    ed_dir$bx <- V(g_dir)$x[match(ed_dir$b, V(g_dir)$name)]
    ed_dir$by <- V(g_dir)$y[match(ed_dir$b, V(g_dir)$name)]
    
    ed_T <- as.data.frame(get.edgelist(T0))
    colnames(ed_T) <- c("a", "b")
    ed_T$ax <- xy[ed_T$a, 1]; ed_T$ay <- xy[ed_T$a, 2]
    ed_T$bx <- xy[ed_T$b, 1]; ed_T$by <- xy[ed_T$b, 2]
    
    p_network <- ggplot() +
      geom_segment(data = subset(ed_dir, type == "road"),
                   aes(x = ax, y = ay, xend = bx, yend = by),
                   linewidth = 0.22, alpha = 0.22) +
      geom_segment(data = subset(ed_dir, type == "flyover"),
                   aes(x = ax, y = ay, xend = bx, yend = by),
                   linewidth = 0.70, alpha = 0.32) +
      geom_segment(data = ed_T,
                   aes(x = ax, y = ay, xend = bx, yend = by),
                   linewidth = 0.32, alpha = 0.55) +
      geom_point(data = df, aes(x = x, y = y), size = 0.50, alpha = 0.90) +
      coord_equal() +
      theme_minimal(base_size = 11) +
      theme(panel.grid = element_blank()) +
      labs(title = "One-way roads + flyovers (gray) with MST geometry overlay",
           x = NULL, y = NULL)
    
    plot_obj <- gridExtra::arrangeGrob(
      p_network,
      p_field("truth", "Truth"),
      p_field("sBAST", "Soft-BAST"),
      p_field("BAST",  "BAST"),
      p_field("GP",    "GP (Euclidean)"),
      p_field("BART",  "BART (axis-aligned)"),
      ncol = 3
    )
  }
  
  list(results = out, plot = plot_obj)
}


## Simulation driver: for each grid size, run R seeded replicates of
## run_one_traffic() and stack their per-method result rows into all_df
## (with `grid` and `rep` columns added).
grid_sizes <- c(12, 14, 16, 18, 20, 22, 24, 26)
R <- 10

## Optional: plot only for one selected grid (first rep)
plot_grid <- "22x28"   # set to e.g. "22x28" or "" to disable
plot_grob <- NULL

rows <- list()
idx <- 1

for (gs in grid_sizes) {
  ## keep aspect similar to your original 22x28
  ## nc = round(gs * 28/22), so gs = 22 produces the tag "22x28".
  nr <- gs
  nc <- round(gs * 28/22)
  
  tag <- paste0(nr, "x", nc)
  cat("Running grid ", tag, " ...\n", sep="")
  
  for (r in seq_len(R)) {
    ## Seeds are distinct per (grid, rep) as long as R < 100.
    seed <- 10000 + 100*gs + r
    do_plot <- (tag == plot_grid && r == 1)
    one <- run_one_traffic(nr, nc, seed = seed, do_plot = do_plot)
    
    tmp <- one$results
    tmp$grid <- tag
    tmp$rep  <- r
    rows[[idx]] <- tmp
    idx <- idx + 1
    
    if (do_plot) plot_grob <- one$plot
  }
}

all_df <- do.call(rbind, rows)


methods <- c("sBAST","BAST","GP","BART")
grids <- unique(all_df$grid)

## Per (grid, method) cell: mean / variance of RMSE and sMAPE over the
## replicates plus mean runtime, as a one-row data.frame.
## The subset is held in a local `cell`. A global named `sub` would shadow
## base::sub, which is called just below, and the global loop variables
## `g`/`k` would overwrite objects defined earlier in the script.
summarise_cell <- function(grid_tag, method_name) {
  cell <- all_df[all_df$grid == grid_tag & all_df$method == method_name, ]
  data.frame(
    grid = grid_tag,
    method = method_name,
    mean_RMSE = mean(cell$RMSE),
    var_RMSE  = var(cell$RMSE),
    mean_sMAPE = mean(cell$sMAPE),
    var_sMAPE  = var(cell$sMAPE),
    mean_time = mean(cell$time_sec),
    stringsAsFactors = FALSE
  )
}

## Grid-major, method-minor order (expand.grid varies its first column
## fastest), i.e. the same order as nested loops over grids then methods.
cells <- expand.grid(method = methods, grid = grids, stringsAsFactors = FALSE)
summ_rows <- mapply(summarise_cell, cells$grid, cells$method,
                    SIMPLIFY = FALSE, USE.NAMES = FALSE)
summary_df <- do.call(rbind, summ_rows)

## order grids by nr (parse before "x")
grid_n <- as.integer(sub("x.*$", "", summary_df$grid))
summary_df <- summary_df[order(grid_n, summary_df$method), ]

print(summary_df)


## Pivot one formatted per-row value of a long (grid, method) summary into a
## grid x method table whose columns are named by method. reshape() prefixes
## the new columns with "val.", which is stripped.
##   df     : long summary with `grid` and `method` columns
##   values : character vector, one entry per row of df, placed in the cells
## Row names of df are kept, so the result matches reshape() on df directly.
widen_by_method <- function(df, values) {
  long <- df[, c("grid", "method")]
  long$val <- values
  wide <- reshape(long, timevar = "method", idvar = "grid", direction = "wide")
  colnames(wide) <- sub("^val\\.", "", colnames(wide))
  wide
}

rmse_wide <- widen_by_method(
  summary_df,
  sprintf("%.3f (%.3f)", summary_df$mean_RMSE, summary_df$var_RMSE)
)
cat("\nRMSE table: mean (variance)\n")
print(rmse_wide)

smape_wide <- widen_by_method(
  summary_df,
  sprintf("%.3f (%.3f)", summary_df$mean_sMAPE, summary_df$var_sMAPE)
)
cat("\nsMAPE table: mean (variance)\n")
print(smape_wide)

time_wide <- widen_by_method(summary_df, sprintf("%.3f", summary_df$mean_time))
cat("\nMean time (sec)\n")
print(time_wide)

## Draw the diagnostic panels captured for plot_grid, if any.
if (!is.null(plot_grob)) {
  grid::grid.newpage()
  grid::grid.draw(plot_grob)
}


## Runtime scaling plot. Attach the node count p = nr * nc, parsed from the
## "NRxNC" grid label, then draw mean runtime against p per method on a
## log-scaled y axis.
time_plot_df <- transform(
  summary_df,
  nr = as.integer(sub("x.*$", "", grid)),
  nc = as.integer(sub("^.*x", "", grid))
)
time_plot_df$p <- with(time_plot_df, nr * nc)


ggplot(time_plot_df, aes(x = p, y = mean_time, color = method)) +
  geom_line(linewidth = 1.1) +
  geom_point(size = 2) +
  scale_y_log10() +
  scale_color_brewer(palette = "Dark2") +
  labs(
    x = "Number of nodes (grid size)",
    y = "Mean runtime (seconds, log scale)",
    color = NULL,
    title = "Computational scaling across grid sizes"
  ) +
  theme_minimal(base_size = 13) +
  theme(
    panel.grid.minor = element_blank(),
    plot.title = element_text(face = "bold")
  )



## Extract GP baseline times
## One row per grid: the GP mean runtime, used as the denominator below.
gp_time <- subset(summary_df, method == "GP")[, c("grid", "mean_time")]
colnames(gp_time)[2] <- "gp_time"

## Merge and compute ratio
## merge() attaches the GP time of the same grid to every method's row;
## time_ratio = method mean runtime / GP mean runtime (GP rows are 1).
time_ratio_df <- merge(summary_df, gp_time, by = "grid")
time_ratio_df$time_ratio <- time_ratio_df$mean_time / time_ratio_df$gp_time

## Add grid size info
## p = nr * nc parsed from the "NRxNC" label.
time_ratio_df$nr <- as.integer(sub("x.*$", "", time_ratio_df$grid))
time_ratio_df$nc <- as.integer(sub("^.*x", "", time_ratio_df$grid))
time_ratio_df$p  <- time_ratio_df$nr * time_ratio_df$nc


## Wide table of formatted ratios: one row per grid, one column per method.
time_ratio_wide <- reshape(
  transform(
    time_ratio_df,
    val = sprintf("%.2f×", time_ratio)
  )[, c("grid", "method", "val")],
  timevar = "method",
  idvar   = "grid",
  direction = "wide"
)

colnames(time_ratio_wide) <- sub("^val\\.", "", colnames(time_ratio_wide))

cat("\nMean time ratio relative to GP\n")
print(time_ratio_wide)


## Ratio vs node count for the non-GP methods; the dashed line at 1 marks
## GP parity.
ggplot(
  subset(time_ratio_df, method != "GP"),
  aes(x = p, y = time_ratio, color = method)
) +
  geom_hline(yintercept = 1, linetype = "dashed", alpha = 0.6) +
  geom_line(linewidth = 1.1) +
  geom_point(size = 2) +
  scale_color_brewer(palette = "Dark2") +
  labs(
    x = "Number of nodes (grid size)",
    y = "Mean runtime / GP runtime",
    color = NULL,
    title = "Runtime relative to Gaussian Process baseline"
  ) +
  theme_minimal(base_size = 13) +
  theme(
    panel.grid.minor = element_blank(),
    plot.title = element_text(face = "bold")
  )

