# data_preparation.R
# This file handles data loading, generation, and cleaning for the conformal prediction experiment.

#' Load and clean an external dataset
#'
#' Reads the CSV file for `dataset_name` from `data_dir`. It drops every row
#' that has a missing, NaN or infinite value in any numeric column, and splits
#' the result into a feature matrix and a response vector. Only numeric
#' columns are used as features, so `x` is always a numeric matrix.
#'
#' Supported datasets (CSV file -> response column; extra excluded columns):
#'   - "california": california_housing.csv -> median_house_value
#'     (ocean_proximity excluded)
#'   - "bio":        CASP.csv -> RMSD
#'   - "bike":       bike_sharing.csv -> cnt
#'   - "kc":         kc_house_data.csv -> price (id, date excluded)
#'
#' @param dataset_name Name of the dataset to load: "california", "bio",
#'   "bike" or "kc".
#' @param data_dir Directory containing the CSV files. Defaults to "data",
#'   resolved relative to the working directory.
#' @return List with `x` (numeric feature matrix) and `y` (response vector).
load_external_data <- function(dataset_name, data_dir = "data") {
  # file: CSV name in data_dir; target: response column; exclude: columns
  # that must never be used as features (besides the target itself).
  specs <- list(
    california = list(file = "california_housing.csv",
                      target = "median_house_value",
                      exclude = "ocean_proximity"),
    bio = list(file = "CASP.csv", target = "RMSD", exclude = character(0)),
    # NOTE(review): if this is the UCI Bike Sharing file, its `casual` and
    # `registered` columns sum to `cnt` and would leak the response.
    # Confirm whether they should be listed in `exclude`.
    bike = list(file = "bike_sharing.csv", target = "cnt",
                exclude = character(0)),
    kc = list(file = "kc_house_data.csv", target = "price",
              exclude = c("id", "date"))
  )

  # `[[` on a list returns NULL for unknown names. Guard the type first so a
  # number is never used as a positional index into `specs`.
  spec <- NULL
  if (is.character(dataset_name) && length(dataset_name) == 1L) {
    spec <- specs[[dataset_name]]
  }
  if (is.null(spec)) {
    stop("Unknown dataset name: ", dataset_name)
  }

  path <- file.path(data_dir, spec$file)
  data <- read.csv(path)
  if (!spec$target %in% names(data)) {
    stop("Response column '", spec$target, "' not found in ", path)
  }

  numeric_cols <- vapply(data, is.numeric, logical(1))
  # is.finite() is FALSE for NA, NaN and +/-Inf. A row is kept only when
  # every numeric column is finite; non-numeric columns are not screened.
  finite_by_col <- lapply(data[numeric_cols], is.finite)
  valid_rows <- Reduce(`&`, finite_by_col, rep(TRUE, nrow(data)))
  data <- data[valid_rows, , drop = FALSE]

  feature_cols <- setdiff(names(data)[numeric_cols],
                          c(spec$target, spec$exclude))
  x <- as.matrix(data[, feature_cols, drop = FALSE])
  y <- data[[spec$target]]

  list(x = x, y = y)
}

# Draw the "multimodal_gaussian" sample before shuffling. y is generated from
# the raw features; the returned x is the scale()d feature matrix.
.simulate_multimodal_gaussian <- function(n_samples, d, mu_base, sigma) {
  n_components <- length(mu_base)
  x <- matrix(rnorm(n_samples * d, mean = 0, sd = 1),
              nrow = n_samples, ncol = d)
  beta <- matrix(rnorm(d * n_components, mean = 0, sd = 1),
                 nrow = d, ncol = n_components)
  gamma <- matrix(rnorm(d * n_components, mean = 0, sd = 0.5),
                  nrow = d, ncol = n_components)

  # Kept row-by-row on purpose: each row draws its component (sample) and
  # then its response (rnorm). Vectorizing would change the order of RNG
  # draws and therefore the output for a given seed.
  y <- numeric(n_samples)
  for (i in seq_len(n_samples)) {
    logits <- as.numeric(x[i, ] %*% beta)
    weights <- exp(logits) / sum(exp(logits))  # softmax mixing weights
    mu <- mu_base + as.numeric(x[i, ] %*% gamma)
    component <- sample.int(n_components, size = 1, prob = weights)
    y[i] <- rnorm(1, mean = mu[component], sd = sigma[component])
  }

  list(x = scale(x), y = y)
}

# Draw the "simple_close_trimodal" sample before shuffling. Component sizes
# are multinomial with equal probabilities, features are scale()d
# Uniform(-5, 5), and each response's mean is shifted by 0.1 * x[, 1].
.simulate_close_mixture <- function(n_samples, d, mu_base, sigma) {
  n_components <- length(mu_base)
  n_counts <- as.numeric(
    rmultinom(1, n_samples, rep(1 / n_components, n_components))
  )
  cat("Sample counts per component:", n_counts, "\n")

  x <- matrix(runif(n_samples * d, min = -5, max = 5),
              nrow = n_samples, ncol = d)
  x <- scale(x)

  # Rows are grouped by component (all of component 1 first, then 2, ...).
  # A zero count simply contributes no rows, so length(y) == n_samples.
  component <- rep(seq_len(n_components), times = n_counts)
  y <- rnorm(n_samples,
             mean = mu_base[component] + 0.1 * x[, 1],
             sd = sigma[component])

  list(x = x, y = y)
}

# Plot a histogram of the simulated responses in the style of data_type.
.plot_simulated_y <- function(y, data_type) {
  if (data_type == "multimodal_gaussian") {
    hist(y, breaks = 100, main = "Distribution of y (Multimodal Gaussian)",
         xlab = "y", col = "lightblue")
  } else {
    hist(y, breaks = 50, main = "Conditional Trimodal Distribution of Y",
         xlab = "Y", col = "lightblue", probability = TRUE)
    lines(density(y), col = "blue", lwd = 2)
  }
  invisible(NULL)
}

#' Generate simulated data
#'
#' Calls `set.seed(seed)`, which resets the global RNG state. It then draws
#' features and responses from a Gaussian mixture and shuffles the rows.
#'
#' - "multimodal_gaussian": x has iid N(0, 1) entries. Each row picks a
#'   component with softmax weights `x[i, ] %*% beta`, and y is drawn from
#'   N((mu_base + x[i, ] %*% gamma)[k], sigma[k]) for the chosen component k.
#'   `beta` and `gamma` are random N(0, 1) and N(0, 0.5^2) coefficient
#'   matrices. x is standardized with `scale()` after y is generated.
#' - "simple_close_trimodal": component sizes are multinomial with equal
#'   probabilities and are printed with `cat()`. x is Uniform(-5, 5),
#'   standardized with `scale()`, and y is drawn from
#'   N(mu_base[k] + 0.1 * x[, 1], sigma[k]).
#'
#' @param n_samples Number of samples to generate.
#' @param data_type "multimodal_gaussian" (default) or
#'   "simple_close_trimodal".
#' @param seed Random seed passed to `set.seed()`.
#' @param d Number of features.
#' @param plot_dist If TRUE, plot a histogram of y.
#' @param mu_base Component base means. NULL uses the built-in values:
#'   c(-15, -10, -5, 0, 5, 10, 15) for "multimodal_gaussian" and
#'   c(0, 1, 2) for "simple_close_trimodal". The number of mixture
#'   components is `length(mu_base)`.
#' @param sigma Component standard deviations, same length as `mu_base`.
#'   NULL uses the built-in values: c(1, 1.2, 1.5, 1, 1.5, 1.2, 1) and
#'   c(0.2, 0.2, 0.2) respectively. Pass it whenever `mu_base` is given.
#' @return List with `x` (an `n_samples` x `d` numeric matrix) and `y`
#'   (numeric response vector of length `n_samples`).
generate_simulated_data <- function(n_samples,
                                    data_type = "multimodal_gaussian",
                                    seed = 123,
                                    d = 5,
                                    plot_dist = FALSE,
                                    mu_base = NULL,
                                    sigma = NULL) {
  set.seed(seed)

  if (data_type == "multimodal_gaussian") {
    if (is.null(mu_base)) mu_base <- c(-15, -10, -5, 0, 5, 10, 15)
    if (is.null(sigma)) sigma <- c(1, 1.2, 1.5, 1, 1.5, 1.2, 1)
    simulate <- .simulate_multimodal_gaussian
  } else if (data_type == "simple_close_trimodal") {
    if (is.null(mu_base)) mu_base <- c(0, 1.0, 2.0)
    if (is.null(sigma)) sigma <- c(0.2, 0.2, 0.2)
    simulate <- .simulate_close_mixture
  } else {
    stop("Unsupported data_type: ", data_type)
  }
  if (length(mu_base) < 1L || length(sigma) != length(mu_base)) {
    stop("'mu_base' and 'sigma' must be non-empty and of equal length")
  }

  sim <- simulate(n_samples, d, mu_base, sigma)

  # Shuffle rows. drop = FALSE keeps x a matrix even when d or n_samples
  # is 1.
  indices <- sample.int(n_samples)
  x <- sim$x[indices, , drop = FALSE]
  y <- sim$y[indices]

  if (plot_dist) {
    .plot_simulated_y(y, data_type)
  }

  list(x = x, y = y)
}

#' Main function to prepare data
#'
#' Dispatches to `load_external_data()` or `generate_simulated_data()` and
#' renames the result's elements to `x_all` / `y_all`.
#'
#' @param data_source Type of data: "external" (load a CSV dataset) or
#'   "simulated".
#' @param dataset_name Name of the external dataset. Required when
#'   `data_source = "external"`; ignored otherwise.
#' @param n_samples Number of samples (simulated data only).
#' @param data_type Distribution type (simulated data only). Defaults to
#'   "multimodal_gaussian", one of the types `generate_simulated_data()`
#'   supports. The previous default, "normal", was not supported and always
#'   raised an error.
#' @param seed Random seed (simulated data only).
#' @param d Number of features (simulated data only).
#' @param plot_dist If TRUE, plot the distribution of y (simulated data only).
#' @return List with `x_all` (feature matrix) and `y_all` (response vector).
prepare_data <- function(data_source = "external",
                         dataset_name = NULL,
                         n_samples = 1000,
                         data_type = "multimodal_gaussian",
                         seed = 123,
                         d = 5,
                         plot_dist = FALSE) {
  if (data_source == "external") {
    if (is.null(dataset_name)) {
      stop("For external data, 'dataset_name' must be provided")
    }
    data <- load_external_data(dataset_name)
  } else if (data_source == "simulated") {
    data <- generate_simulated_data(
      n_samples = n_samples,
      data_type = data_type,
      seed = seed,
      d = d,
      plot_dist = plot_dist
    )
  } else {
    stop("Invalid data_source: ", data_source,
         ". Choose 'external' or 'simulated'.")
  }

  list(x_all = data$x, y_all = data$y)
}