# run_experiment.R
# This file contains the main logic for running the conformal prediction experiment,
# including data splitting, model fitting, prediction, and result computation.

# Load functions and setup
source("R/basic_functions.R")
source("R/experiment_setup.R")

# Load required packages
library(predictionBands)
library(FlexCoDE)
library(FNN)
library(ggplot2)
library(caret)  
library(dplyr)

# Per-run result storage ----
# Every method gets a trio of length-`num_runs` vectors (coverage rate,
# mean band length, mean number of band pieces); the main loop writes slot
# `r` of each after run `r`.

# Original cd-split
coverage_rates_original <- avg_interval_lengths_original <-
  avg_interval_counts_original <- numeric(num_runs)

# Smoothed cd-split (cd_fourier)
coverage_rates_smoothed <- avg_interval_lengths_smoothed <-
  avg_interval_counts_smoothed <- numeric(num_runs)

# dist-split
coverage_rates_dist <- avg_interval_lengths_dist <-
  avg_interval_counts_dist <- numeric(num_runs)

# hpd-split
coverage_rates_hpd <- avg_interval_lengths_hpd <-
  avg_interval_counts_hpd <- numeric(num_runs)



# Parse a band string such as "(a,b)U(c,d)" into a numeric matrix with one
# row per interval.
#
# interval_str: character. NULL, NA, or empty input yields a 0-row matrix.
#   Round or square brackets and surrounding whitespace are tolerated; a
#   character vector is treated as the union of its non-NA elements.
# Returns: numeric matrix with dimnames list(NULL, c("lower", "upper")),
#   including in the empty case, so callers can always index by column name.
# Errors (explicitly) when a piece does not hold exactly two comma-separated
#   values, instead of letting rbind() silently recycle a malformed piece.
parse_interval <- function(interval_str) {
  empty <- matrix(numeric(0), nrow = 0, ncol = 2,
                  dimnames = list(NULL, c("lower", "upper")))
  if (is.null(interval_str)) {
    return(empty)
  }

  interval_str <- as.character(interval_str)
  interval_str <- interval_str[!is.na(interval_str)]

  # Split the union into pieces, then strip brackets and blanks.
  pieces <- unlist(strsplit(interval_str, "U", fixed = TRUE), use.names = FALSE)
  pieces <- trimws(gsub("[\\[\\]()]", "", pieces, perl = TRUE))
  pieces <- pieces[nzchar(pieces)]
  if (length(pieces) == 0L) {
    return(empty)
  }

  bounds <- lapply(strsplit(pieces, ",", fixed = TRUE), as.numeric)
  if (any(lengths(bounds) != 2L)) {
    stop("Malformed interval string '", paste(interval_str, collapse = "U"),
         "': expected pieces of the form (lower,upper).", call. = FALSE)
  }

  matrix(unlist(bounds, use.names = FALSE), ncol = 2, byrow = TRUE,
         dimnames = list(NULL, c("lower", "upper")))
}
# Main experiment loop ----

# Miscoverage level passed as `alpha` to every method. It defaults to the
# value the calls previously hard-coded (0.1); define `alpha_level` in the
# setup script to override it.
if (!exists("alpha_level", inherits = FALSE)) {
  alpha_level <- 0.1
}

# Per-point score template: whether the band contains the response, its
# total length, and its number of disjoint pieces. Empty bands score zeros.
empty_band_score <- c(covered = 0, length = 0, count = 0)

# Score one test point against a band encoded as a string such as
# "(a,b)U(c,d)" (cd / dist / hpd outputs). NULL, empty, NA, or
# "NA"-containing bands are treated as empty.
score_string_band <- function(band, y_val) {
  if (length(band) == 0L || anyNA(band) ||
      any(grepl("NA", band, fixed = TRUE)) || !any(nzchar(band))) {
    return(empty_band_score)
  }
  bounds <- parse_interval(paste(band, collapse = "U"))
  if (nrow(bounds) == 0L) {
    return(empty_band_score)
  }
  lower <- bounds[, 1]
  upper <- bounds[, 2]
  c(
    covered = as.numeric(isTRUE(any(y_val >= lower & y_val <= upper))),
    length = sum(upper - lower),
    count = nrow(bounds)
  )
}

# Score one test point against a band given as a list of c(lower, upper)
# pairs (the cd_fourier output). Elements that are not NA-free pairs are
# skipped for coverage and length, but still contribute to the piece count.
score_list_band <- function(band, y_val) {
  if (length(band) == 0L) {
    return(empty_band_score)
  }
  is_pair <- vapply(band, function(iv) length(iv) == 2L && !anyNA(iv),
                    logical(1))
  pairs <- band[is_pair]
  lower <- vapply(pairs, function(iv) as.numeric(iv[[1]]), numeric(1))
  upper <- vapply(pairs, function(iv) as.numeric(iv[[2]]), numeric(1))
  c(
    covered = as.numeric(isTRUE(any(y_val >= lower & y_val <= upper))),
    length = sum(upper - lower),
    count = length(band)
  )
}

# Score every test point with `scorer` and average each metric over ALL test
# points (empty bands count as not covered, length 0, 0 pieces), so the four
# methods are summarised identically.
average_band_scores <- function(intervals, y_new, scorer) {
  per_point <- vapply(
    seq_along(y_new),
    function(i) scorer(intervals[[i]], y_new[i]),
    empty_band_score
  )
  rowMeans(per_point)
}

for (r in seq_len(num_runs)) {
  # set.seed(42 + r)  # uncomment for reproducible splits

  # Random train/test split; test rows come from rows not used for training.
  # Indexing through sample.int() avoids sample()'s special case when only
  # one candidate row remains (it would then sample from 1:that_row).
  all_rows <- seq_len(nrow(x_all))
  train_indices <- sample(all_rows, n)
  x <- x_all[train_indices, , drop = FALSE]
  y <- y_all[train_indices]

  remaining_rows <- setdiff(all_rows, train_indices)
  test_indices <- remaining_rows[sample.int(length(remaining_rows), n_new)]
  xnew <- x_all[test_indices, , drop = FALSE]
  ynew <- y_all[test_indices]

  # Optional standardisation (disabled):
  # scalerX <- preProcess(x, method = c("center", "scale"))
  # x <- predict(scalerX, x)
  # xnew <- predict(scalerX, xnew)
  # mean_y_train <- mean(abs(y))
  # y <- as.numeric(y) / mean_y_train
  # ynew <- as.numeric(ynew) / mean_y_train

  ## 1. cd-split
  fit_cd <- fit_predictionBands(x, y, per_train = per_train,
                                per_val = per_val, per_ths = per_ths)
  bands_original <- predict(fit_cd, xnew, type = "cd", alpha = alpha_level)
  scores_orig <- average_band_scores(bands_original$intervals, ynew,
                                     score_string_band)
  coverage_rates_original[r] <- scores_orig[["covered"]]
  avg_interval_lengths_original[r] <- scores_orig[["length"]]
  avg_interval_counts_original[r] <- scores_orig[["count"]]

  ## 2. dist-split (reuses the cd-split fit)
  bands_dist <- predict(fit_cd, xnew, type = "dist", alpha = alpha_level)
  scores_dist <- average_band_scores(bands_dist$intervals, ynew,
                                     score_string_band)
  coverage_rates_dist[r] <- scores_dist[["covered"]]
  avg_interval_lengths_dist[r] <- scores_dist[["length"]]
  avg_interval_counts_dist[r] <- scores_dist[["count"]]

  ## 3. hpd-split (reuses the cd-split fit)
  bands_hpd <- predict_hpd(fit_cd, xnew, type = "hpd", alpha = alpha_level)
  scores_hpd <- average_band_scores(bands_hpd$intervals, ynew,
                                    score_string_band)
  coverage_rates_hpd[r] <- scores_hpd[["covered"]]
  avg_interval_lengths_hpd[r] <- scores_hpd[["length"]]
  avg_interval_counts_hpd[r] <- scores_hpd[["count"]]

  ## 4. scd-split
  fit_smooth <- fit_predictionBands_smooth(x, y, per_train = per_train,
                                           per_val = per_val,
                                           per_ths = per_ths,
                                           sigma = sigma_value)
  bands_smoothed <- predict_smooth(fit_smooth, xnew, type = "cd_fourier",
                                   sigma = sigma_value, alpha = alpha_level)
  # Averaged over all test points like the other methods (previously empty
  # bands were dropped here, which inflated length/count and gave NaN when
  # every band was empty).
  scores_smoo <- average_band_scores(bands_smoothed$intervals, ynew,
                                     score_list_band)
  coverage_rates_smoothed[r] <- scores_smoo[["covered"]]
  avg_interval_lengths_smoothed[r] <- scores_smoo[["length"]]
  avg_interval_counts_smoothed[r] <- scores_smoo[["count"]]

  # Plotting (disabled):
  # par(mfrow = c(2, 20))
  # plot(bands_original, ynew)
  # plot(bands_smoothed, ynew)

  cat("Run", r, ":\n")
  cat("  cd-split Coverage =", coverage_rates_original[r],
      ", average length =", avg_interval_lengths_original[r],
      ", average number =", avg_interval_counts_original[r], "\n")
  cat("  dist Coverage =", coverage_rates_dist[r],
      ", average length =", avg_interval_lengths_dist[r],
      ", average number =", avg_interval_counts_dist[r], "\n")
  cat("  hpd Coverage =", coverage_rates_hpd[r],
      ", average length =", avg_interval_lengths_hpd[r],
      ", average number =", avg_interval_counts_hpd[r], "\n")
  cat("  scd-split (cd_fourier, sigma =", sigma_value,
      ") average Coverage Rate =", coverage_rates_smoothed[r],
      ", average length =", avg_interval_lengths_smoothed[r],
      ", average number =", avg_interval_counts_smoothed[r], "\n\n")
}

# Summary across all runs ----
cat("After", num_runs, "experiments：\n")
local({
  # Print the run-averaged coverage, length and piece count for one method.
  # `coverage_label` holds the cat() pieces placed before
  # "average Coverage Rate ="; `terminator` ends the final line.
  report <- function(label, coverage, lengths, counts,
                     coverage_label = list(label), terminator = "\n\n") {
    do.call(cat, c(coverage_label,
                   list("average Coverage Rate =", mean(coverage), "\n")))
    cat(label, "average length =", mean(lengths), "\n")
    cat(label, "average number =", mean(counts), terminator)
  }

  report("cd-split", coverage_rates_original,
         avg_interval_lengths_original, avg_interval_counts_original)
  report("dist-split", coverage_rates_dist,
         avg_interval_lengths_dist, avg_interval_counts_dist)
  report("hpd-split", coverage_rates_hpd,
         avg_interval_lengths_hpd, avg_interval_counts_hpd)
  report("scd-split (cd_fourier)", coverage_rates_smoothed,
         avg_interval_lengths_smoothed, avg_interval_counts_smoothed,
         coverage_label = list("scd-split (cd_fourier, sigma =",
                               sigma_value, ")"),
         terminator = "\n")
})

# Optional: persist per-run metrics and their means to output/ ----
dir.create("output", showWarnings = FALSE)
timestamp <- format(Sys.time(), "%Y-%m-%d_%H-%M-%S")
local({
  # Bundle one method's per-run vectors together with their means.
  bundle <- function(coverage, lengths, counts) {
    list(
      coverage = coverage,
      avg_coverage = mean(coverage),
      lengths = lengths,
      avg_lengths = mean(lengths),
      counts = counts,
      avg_counts = mean(counts)
    )
  }

  results <- list(
    original = bundle(coverage_rates_original,
                      avg_interval_lengths_original,
                      avg_interval_counts_original),
    dist = bundle(coverage_rates_dist,
                  avg_interval_lengths_dist,
                  avg_interval_counts_dist),
    hpd = bundle(coverage_rates_hpd,
                 avg_interval_lengths_hpd,
                 avg_interval_counts_hpd),
    smoothed = bundle(coverage_rates_smoothed,
                      avg_interval_lengths_smoothed,
                      avg_interval_counts_smoothed)
  )

  saveRDS(results,
          file = paste0("output/experiment_results_", timestamp, ".rds"))
})