# -------------------------------------------------------------------
# Covariates:
#   (1) log_rv of the previous day
#   (2) average log_rv of the last 5 days
#   (3) average log_rv of the last 21 days
# -------------------------------------------------------------------

source("data_loader_RV.R")
source("Functions_spci_RV.R")


library(ggplot2)


# Fix the RNG state up front so every later random step is reproducible.
set.seed(2020)


# -----------------------------
# 1) Load and prepare the data
# -----------------------------
# load_RV_dataset() returns a list; the script uses its `X_full`, `Y_full`
# and `dates` elements.
data_list <- load_RV_dataset("SPY_RV.csv")

dates  <- data_list$dates
Y_full <- as.numeric(data_list$Y_full)  # force a plain numeric response
X_full <- data_list$X_full

cat("Loaded rows (after lag construction):", nrow(X_full), "\n")

# -----------------------------
# 2) Train/test split 
# -----------------------------
# Split the rows in their stored order (presumably chronological -- confirm
# against the loader): the first 80% become the training set, the rest the
# test set. No shuffling is performed.
n <- nrow(X_full)
n_train <- floor(0.8 * n)

# Guard against degenerate splits. With `1:n_train` indexing, n_train == 0
# would expand to c(1, 0) and silently put row 1 in BOTH train and test.
if (n_train < 1 || n_train >= n) {
  stop(
    "Need at least 2 rows for a train/test split; got n = ", n, ".",
    call. = FALSE
  )
}

# seq_len() never counts backwards, unlike `a:b`.
train_idx <- seq_len(n_train)
test_idx  <- n_train + seq_len(n - n_train)

X_train <- X_full[train_idx, , drop = FALSE]
Y_train <- Y_full[train_idx]

X_test  <- X_full[test_idx, , drop = FALSE]
Y_test  <- Y_full[test_idx]
dates_test <- dates[test_idx]

cat("Train size:", n_train, " Test size:", n - n_train, "\n")

# -----------------------------
# 3) Fit bootstrap ensemble
# -----------------------------
# Progress message before the fitting call.
cat("Fitting bootstrap ensemble models...\n")

# Settings handed to the fitter through its `fit_func_params` argument.
fit_func_params <- list(n_estimators = 50, max_depth = 10)

# Call the ensemble fitter on the training split with B = 50.
bootstrap_results <- fit_bootstrap_models_online(
  X_train = X_train, Y_train = Y_train,
  B = 50, fit_func_params = fit_func_params
)

# Only the `models` element of the result is used by the rest of the script.
models <- bootstrap_results$models
cat("Done. Number of base learners:", length(models), "\n")

# -----------------------------
# 4) Compare methods
# -----------------------------
# Settings read by run_and_plot() and forwarded to
# compute_PIs_Ensemble_online().
# alpha: passed as `alpha` (presumably the miscoverage level, i.e. target
#        coverage 1 - alpha -- confirm against Functions_spci_RV.R).
alpha <- 0.1
# past_window: passed as `past_window`; capped at the training-set size.
past_window <- min(600, n_train)

# Forwarded as `lag_L`, `qrf_num_trees` and `qrf_max_depth` respectively.
lag_L <- 75
num_trees <- 100
max_depth <- 10

# Run one prediction-interval method on the test set, report its metrics,
# and plot the intervals against the observed values.
#
# Args:
#   method_name: forwarded as `method` to compute_PIs_Ensemble_online();
#                also shown in the console banner and the plot title.
#   ribbon_fill: fill colour of the interval ribbon (default "grey70").
#
# Reads from the script environment: models, X_train, Y_train, X_test,
# Y_test, dates_test, alpha, past_window, lag_L, num_trees, max_depth.
#
# Side effects: writes a banner and the metrics to the console via cat()
# and prints the ggplot object.
#
# Returns:
#   A list with elements `out` (raw result of compute_PIs_Ensemble_online),
#   `metrics` (result of get_results), `plot_df` (data frame behind the
#   plot) and `plot` (the ggplot object).
run_and_plot <- function(method_name, ribbon_fill = "grey70") {
  cat("\n============================\n")
  cat("Running method:", method_name, "\n")
  cat("============================\n")

  out <- compute_PIs_Ensemble_online(
    models = models,
    X_train = X_train, Y_train = Y_train,
    X_predict = X_test, Y_predict = Y_test,
    alpha = alpha,
    past_window = past_window,
    method = method_name,
    lag_L = lag_L,
    qrf_num_trees = num_trees,
    qrf_max_depth = max_depth
  )

  # Coverage and average width of the intervals on the test responses.
  metrics <- get_results(out$PIs, Y_test)

  cat(sprintf("Coverage: %.2f%%\n", 100 * metrics$coverage))
  cat(sprintf("Avg. width: %.6f\n", metrics$width))

  # One row per test observation: date, observed value, interval centre
  # and interval bounds.
  plot_df <- data.frame(
    date = dates_test,
    y = Y_test,
    center = out$Ensemble_pred_interval_centers,
    lower = out$PIs$lower,
    upper = out$PIs$upper
  )

  # Ribbon = prediction interval, solid line = observed series,
  # dashed line = interval centre.
  p <- ggplot(plot_df, aes(x = date)) +
    geom_ribbon(
      aes(ymin = lower, ymax = upper),
      fill = ribbon_fill, alpha = 0.25
    ) +
    geom_line(aes(y = y), linewidth = 0.4) +
    geom_line(aes(y = center), linetype = "dashed", linewidth = 0.4) +
    labs(
      title = paste0(method_name, " on RV (alpha=", alpha, ")"),
      x = NULL,
      # The plotted series is the RV target, not a return series; the old
      # "Daily return" label contradicted the title.
      # NOTE(review): if the target is log-RV, relabel accordingly.
      y = "RV"
    ) +
    theme_minimal()

  print(p)

  list(out = out, metrics = metrics, plot_df = plot_df, plot = p)
}


# Evaluate the "CQACP" method; its interval ribbon is filled in dark orange.
CQACP_res <- run_and_plot(method_name = "CQACP", ribbon_fill = "darkorange")

# -----------------------------
# 5) Summary table
# -----------------------------
# One row per evaluated method (currently only CQACP), holding the coverage
# and average interval width reported by run_and_plot().
summary_tbl <- data.frame(
  method    = "CQACP",
  coverage  = c(CQACP_res$metrics$coverage),
  avg_width = c(CQACP_res$metrics$width)
)

print(summary_tbl)
