# Functions_spci.R
# -------------------------------------------------------------------
# Core functions for:
#   - Bootstrap RF ensemble point prediction
#   - Online prediction intervals:
#       * EnbPI 
#       * SPCI 
#       * CQACP 
# -------------------------------------------------------------------

suppressPackageStartupMessages({
  library(ranger)
})

# -------------------------
# Helper: fit one RF model
# -------------------------
# Fit a single ranger random forest regressing Y on the columns of X.
#
# Args:
#   X            predictors (matrix or data.frame), one row per observation.
#   Y            response; coerced with as.numeric(). Must supply exactly one
#                value per row of X.
#   n_estimators number of trees, forwarded as ranger's `num.trees`.
#   max_depth    forwarded unchanged as ranger's `max.depth` (NULL by default).
#
# Returns: the object returned by ranger().
fit_rf_model <- function(X, Y, n_estimators = 50, max_depth = NULL) {
  y_num <- as.numeric(Y)

  # data.frame() recycles a shorter column when the lengths divide evenly,
  # which would silently pair responses with the wrong rows. Fail instead.
  if (length(y_num) != NROW(X)) {
    stop("`Y` must have one value per row of `X`.", call. = FALSE)
  }

  # Keep the bare `X` expression here: data.frame() derives the column name
  # of an unnamed single-column matrix from the argument expression.
  train_df <- data.frame(Y = y_num, X)

  # No `seed` is passed, so ranger's default seeding applies.
  ranger(
    dependent.variable.name = "Y",
    data = train_df,
    num.trees = n_estimators,
    max.depth = max_depth
  )
}

# -----------------------------------------
# Fit bootstrap ensemble of RF base learners
# -----------------------------------------
# Fit B random-forest base learners, each on a bootstrap resample
# (rows drawn with replacement, same size as the training set).
#
# Args:
#   X_train         predictors (matrix or data.frame); rows are resampled.
#   Y_train         response, indexed element-wise with the same row draws.
#   B               number of bootstrap models to fit.
#   fit_func_params optional list of hyperparameters; recognised entries are
#                   `n_estimators` (default 50 when absent) and `max_depth`
#                   (NULL when absent). May be omitted entirely.
#
# Returns: list(models = <list of B fitted models from fit_rf_model()>).
fit_bootstrap_models_online <- function(X_train, Y_train, B = 25, fit_func_params = list()) {
  n_train <- nrow(X_train)

  # Exact-name lookup: `$` partial-matches element names on lists.
  n_estimators <- fit_func_params[["n_estimators"]]
  if (is.null(n_estimators)) {
    n_estimators <- 50
  }
  max_depth <- fit_func_params[["max_depth"]]

  models <- vector("list", B)
  for (b in seq_len(B)) {
    # One draw of row indices is shared by X and Y so pairs stay aligned.
    idx <- sample(seq_len(n_train), size = n_train, replace = TRUE)
    models[[b]] <- fit_rf_model(
      X_train[idx, , drop = FALSE],
      Y_train[idx],
      n_estimators = n_estimators,
      max_depth = max_depth
    )
  }

  list(models = models)
}

# --------------------------------
# Helper: ensemble mean prediction
# --------------------------------
# Average the predictions of all ensemble members for each row of X_new.
#
# Args:
#   models list of fitted models; each must support
#          predict(m, data = <data.frame>)$predictions.
#   X_new  new predictors (matrix or data.frame).
#
# Returns: numeric vector with one ensemble-mean prediction per row of X_new
#          (length 0 when X_new has no rows).
# Errors if `models` is empty, since there is nothing to average.
predict_ensemble_mean <- function(models, X_new) {
  if (length(models) == 0L) {
    stop("`models` must contain at least one fitted model.", call. = FALSE)
  }

  # Build the prediction frame once instead of once per model. Keep the bare
  # `X_new` expression: data.frame() derives the column name of an unnamed
  # single-column matrix from the argument expression.
  new_df <- data.frame(X_new)
  n_new <- nrow(new_df)

  preds <- vapply(
    models,
    function(m) as.numeric(predict(m, data = new_df)$predictions),
    numeric(n_new)
  )

  # vapply() drops to a plain vector when n_new == 1; force the
  # rows = observations, columns = models layout in every case.
  pred_mat <- matrix(preds, nrow = n_new, ncol = length(models))
  rowMeans(pred_mat)
}


# -----------------------------
# Helper: fit QRF for residuals
# -----------------------------
# Fit a ranger quantile regression forest of residuals y on features Z.
#
# Args:
#   Z         feature matrix or data.frame, one row per observation.
#   y         target values; coerced with as.numeric(). Must supply exactly
#             one value per row of Z.
#   num_trees number of trees, forwarded as ranger's `num.trees`.
#   max_depth forwarded unchanged as ranger's `max.depth` (NULL by default).
#
# Returns: the object returned by ranger(), fitted with quantreg = TRUE and
#          keep.inbag = TRUE.
fit_qrf_residual_model <- function(Z, y, num_trees = 500, max_depth = NULL) {
  y_num <- as.numeric(y)

  # data.frame() recycles a shorter column when the lengths divide evenly,
  # which would silently misalign targets and features. Fail instead.
  if (length(y_num) != NROW(Z)) {
    stop("`y` must have one value per row of `Z`.", call. = FALSE)
  }

  # Keep the bare `Z` expression here: data.frame() derives the column name
  # of an unnamed single-column matrix from the argument expression.
  train_df <- data.frame(y = y_num, Z)

  # No `seed` is passed, so ranger's default seeding applies.
  ranger(
    dependent.variable.name = "y",
    data = train_df,
    num.trees = num_trees,
    max.depth = max_depth,
    quantreg = TRUE,
    keep.inbag = TRUE
  )
}

# -------------------------------------------------------
# Helper: predict quantiles from a fitted ranger QRF model
# -------------------------------------------------------
# Query a fitted quantile forest for the requested quantile levels.
#
# Args:
#   qrf_model fitted model supporting predict(..., type = "quantiles").
#   Z_new     feature row(s) to condition on (matrix or data.frame).
#   taus      quantile levels to request.
#
# Returns: the `predictions` element flattened with as.numeric(). For a
#          single-row Z_new this is one value per entry of `taus`; for several
#          rows the flattening follows R's column-major order.
predict_qrf_quantiles <- function(qrf_model, Z_new, taus) {
  quantile_pred <- predict(
    qrf_model,
    data = data.frame(Z_new),
    type = "quantiles",
    quantiles = taus
  )
  flat <- as.numeric(quantile_pred$predictions)
  flat
}

# ------------------------------------------------------------
# Quantile Calibration via Cornish-Fisher basis
# ------------------------------------------------------------
# Smooth a raw quantile curve by least-squares projection onto a four-term
# basis in z = qnorm(tau): 1, z, z^2 - 1, z^3 - 3z.
#
# Args:
#   taus  quantile levels (any order).
#   Q_raw raw quantile values, one per entry of `taus`.
#
# Returns: list(Q, ok).
#   ok = TRUE  -> Q is the fitted curve, made non-decreasing in tau.
#   ok = FALSE -> Q is Q_raw returned untouched. This happens when the fit
#                 errors, any coefficient is NA (e.g. too few points for four
#                 basis terms), or the coefficient on z is not positive.
qcm_calibrate_curve <- function(taus, Q_raw) {
  z <- qnorm(taus)
  basis <- cbind( # Correspond to K = 4
    1,
    z,
    z^2 - 1,
    z^3 - 3 * z
  )

  # Least squares: minimize ||basis %*% beta - Q_raw||. Any failure inside
  # lm.fit (non-finite z, NA in Q_raw, dimension mismatch) falls back to the
  # raw curve and is reported through ok = FALSE.
  fit <- tryCatch(
    lm.fit(x = basis, y = as.numeric(Q_raw)),
    error = function(e) NULL
  )
  if (is.null(fit) || anyNA(fit$coefficients)) {
    return(list(Q = Q_raw, ok = FALSE))
  }

  beta <- as.numeric(fit$coefficients)

  # Sanity check: the coefficient on z acts as a scale and must be positive.
  if (beta[2] <= 0) {
    return(list(Q = Q_raw, ok = FALSE))
  }

  Q_cal <- as.numeric(basis %*% beta)

  # Monotone repair (prevents quantile crossing). The running maximum is
  # taken in increasing-tau order and written back in the caller's order,
  # so unsorted `taus` are handled correctly.
  tau_order <- order(taus)
  Q_cal[tau_order] <- cummax(Q_cal[tau_order])

  list(Q = Q_cal, ok = TRUE)
}

# ------------------------------------------------------------
# Choose the narrowest (1-alpha) interval via beta-shift search
# ------------------------------------------------------------
# Among intervals [Q(b), Q(1 - alpha + b)] for shifts b on an equally spaced
# grid over [0, alpha], pick the narrowest one.
#
# Args:
#   taus_grid quantile levels of the curve; passed to approx() with
#             ties = "ordered", i.e. treated as already sorted.
#   Q_grid    quantile values at `taus_grid`.
#   alpha     miscoverage level; each candidate spans 1 - alpha of mass.
#   n_beta    number of candidate shifts.
#
# Returns: list(q_low, q_high, beta_star) for the first shift attaining the
#          minimum width. Levels outside the grid are clamped to the nearest
#          grid endpoint (rule = 2).
choose_narrowest_interval <- function(taus_grid, Q_grid, alpha, n_beta = 10) {
  shifts <- seq(0, alpha, length.out = n_beta)

  # Linear interpolation of the quantile curve at arbitrary levels.
  curve_at <- function(levels) {
    approx(
      x = taus_grid, y = Q_grid,
      xout = levels,
      rule = 2, ties = "ordered"
    )$y
  }

  lower_ends <- curve_at(shifts)
  upper_ends <- curve_at(1 - alpha + shifts)

  winner <- which.min(upper_ends - lower_ends)

  # The trailing [1L] yields NA (not a zero-length value) when no candidate
  # has a usable width.
  list(
    q_low = lower_ends[winner][1L],
    q_high = upper_ends[winner][1L],
    beta_star = shifts[winner]
  )
}

# ------------------------------------------------------------
# Main: compute online prediction intervals for the ensemble
# ------------------------------------------------------------
# Sequentially build (1 - alpha) prediction intervals around the ensemble
# mean, updating the residual history after each test observation.
#
# Args:
#   models         list of fitted base learners for predict_ensemble_mean().
#   X_train, Y_train     training predictors / responses; they seed the
#                        residual history (response minus ensemble mean).
#   X_predict, Y_predict test predictors / responses; Y_predict[t] is used
#                        only after the interval for step t has been set.
#   alpha          miscoverage level.
#   past_window    number of most recent residuals used at each step (>= 1).
#   method         "EnbPI": empirical residual quantiles at alpha/2 and
#                  1 - alpha/2.
#                  "SPCI": quantile forest of the next residual on its
#                  `lag_L` predecessors, then the narrowest interval.
#                  "CQACP": as SPCI, with the quantile curve first passed
#                  through qcm_calibrate_curve().
#   lag_L          number of lagged residuals used as features (integer >= 1).
#   qrf_num_trees, qrf_max_depth  forwarded to fit_qrf_residual_model().
#   quantile_grid  quantile levels requested from the quantile forest.
#
# Returns: list(PIs = data.frame(lower, upper), one row per test point;
#               Ensemble_pred_interval_centers = ensemble means on X_predict).
#
# SPCI/CQACP fall back to the EnbPI quantiles while the window holds no more
# than lag_L + 5 residuals. A new quantile forest is fitted at every step.
compute_PIs_Ensemble_online <- function(models,
                                       X_train, Y_train,
                                       X_predict, Y_predict,
                                       alpha = 0.1, past_window = 2000,
                                       method = c("EnbPI", "SPCI", "CQACP"),
                                       lag_L = 30,
                                       qrf_num_trees = 500,
                                       qrf_max_depth = NULL,
                                       quantile_grid = seq(0.005, 0.995, by = 0.005)) {
  method <- match.arg(method)

  stopifnot(
    is.numeric(past_window), length(past_window) == 1L, past_window >= 1,
    is.numeric(lag_L), length(lag_L) == 1L, lag_L >= 1, lag_L == round(lag_L)
  )

  y_train <- as.numeric(Y_train)
  y_test <- as.numeric(Y_predict)
  n_test <- NROW(X_predict)

  # Mismatched lengths would otherwise be recycled or yield NA residuals.
  if (length(y_train) != NROW(X_train)) {
    stop("`Y_train` must have one value per row of `X_train`.", call. = FALSE)
  }
  if (length(y_test) != n_test) {
    stop("`Y_predict` must have one value per row of `X_predict`.", call. = FALSE)
  }

  # Ensemble centers
  center_train <- as.numeric(predict_ensemble_mean(models, X_train))
  center_test <- predict_ensemble_mean(models, X_predict)

  # Unconditional bounds: empirical residual quantiles at alpha/2, 1 - alpha/2.
  empirical_bounds <- function(res) {
    as.numeric(quantile(
      res,
      probs = c(alpha / 2, 1 - alpha / 2),
      type = 8, na.rm = TRUE
    ))
  }

  # Explicit, shared column names for the lag features. Without them an
  # unnamed one-column matrix (lag_L = 1) is named after the argument
  # expression inside data.frame(), which differs between the fit and the
  # predict helper and breaks prediction.
  lag_names <- paste0("lag", seq_len(lag_L))

  # Conditional bounds from a quantile forest fitted on the current window.
  conditional_bounds <- function(res) {
    emb <- embed(res, lag_L + 1)
    y_q <- emb[, 1]                    # residual at time s
    Z_q <- emb[, -1, drop = FALSE]     # residuals at s-1, ..., s-lag_L
    # Most recent lags, newest first, matching the column order of embed().
    Z_cur <- matrix(rev(tail(res, lag_L)), nrow = 1)
    colnames(Z_q) <- lag_names
    colnames(Z_cur) <- lag_names

    qrf <- fit_qrf_residual_model(
      Z = Z_q, y = y_q,
      num_trees = qrf_num_trees, max_depth = qrf_max_depth
    )
    Q_use <- predict_qrf_quantiles(qrf, Z_new = Z_cur, taus = quantile_grid)

    if (method == "CQACP") {
      Q_use <- qcm_calibrate_curve(quantile_grid, Q_use)$Q
    }

    # Choose narrowest (1 - alpha) mass interval over beta-shifts
    best <- choose_narrowest_interval(quantile_grid, Q_use, alpha = alpha, n_beta = 10)
    c(best$q_low, best$q_high)
  }

  # Residual history, always trimmed to the rolling window so it never
  # grows beyond `past_window` entries.
  res_window <- tail(y_train - center_train, past_window)

  lower <- rep(NA_real_, n_test)
  upper <- rep(NA_real_, n_test)

  for (t in seq_len(n_test)) {
    c_t <- center_test[t]

    use_conditional <- method != "EnbPI" && length(res_window) > (lag_L + 5)
    bounds <- if (use_conditional) {
      conditional_bounds(res_window)
    } else {
      empirical_bounds(res_window)
    }

    # Set PI for Y_t
    lower[t] <- c_t + bounds[1]
    upper[t] <- c_t + bounds[2]

    # After observing Y_predict[t], update the residual window
    res_window <- tail(c(res_window, y_test[t] - c_t), past_window)
  }

  list(
    PIs = data.frame(lower = lower, upper = upper),
    Ensemble_pred_interval_centers = center_test
  )
}

# -----------------------------
# Evaluation: coverage and width
# -----------------------------
# Summarise a set of prediction intervals against the observed responses.
#
# Args:
#   PIs       object with `lower` and `upper` interval bounds.
#   Y_predict observed responses; coerced with as.numeric().
#
# Returns: list(coverage, width) where coverage is the share of observations
#          inside [lower, upper] (bounds inclusive) and width is the mean of
#          upper - lower. NA entries are dropped from both averages.
get_results <- function(PIs, Y_predict) {
  observed <- as.numeric(Y_predict)
  lo <- PIs$lower
  hi <- PIs$upper

  is_covered <- observed >= lo & observed <= hi

  list(
    coverage = mean(is_covered, na.rm = TRUE),
    width = mean(hi - lo, na.rm = TRUE)
  )
}
