
#' Sample uniformly from (-upper_bound, -lower_bound) U (lower_bound, upper_bound)
#'
#' Draws values with `runif()` on `[0, 2 * (upper_bound - lower_bound))` and
#' shifts them so that the first half of that range lands on the negative
#' interval and the second half on the positive interval.
#'
#' @param n Number of samples to draw (passed to `runif()`).
#' @param lower_bound Smallest absolute value a sample may take; must be a
#'   single non-negative number.
#' @param upper_bound Largest absolute value a sample may take; must be a
#'   single number not smaller than `lower_bound`.
#' @return A numeric vector of samples (`numeric(0)` when `n` is 0).
sample_symmetric_uniform <- function(n, lower_bound = 0.5, upper_bound = 2) {
  stopifnot(
    is.numeric(lower_bound), length(lower_bound) == 1L,
    is.numeric(upper_bound), length(upper_bound) == 1L,
    lower_bound >= 0, lower_bound <= upper_bound
  )
  interval <- upper_bound - lower_bound

  u <- runif(n, min = 0, max = interval * 2)

  # Shift everything onto the negative interval first, then move the draws
  # from the upper half of the range across the gap (-lower_bound, lower_bound).
  # Plain subsetting (rather than ifelse()) keeps the result numeric even
  # when n is 0.
  samples <- u - upper_bound
  positive_half <- u >= interval
  samples[positive_half] <- samples[positive_half] + 2 * lower_bound

  samples
}

#' Build the node function used for a root node.
#'
#' @return A function of `(parent_data, noise)` carrying the class
#'   `"root_node_function"`. It ignores `parent_data` and returns `noise`
#'   unchanged.
get_root_node_function <- function() {
  structure(
    function(parent_data, noise) noise,
    class = "root_node_function"
  )
}

#' Create a linear node function whose weights come from a weight sampler.
#'
#' @param n_parents Number of parents; passed to `weight_func` and stored in
#'   the `"n_parents"` attribute of the result.
#' @param weight_func Function called as `weight_func(n_parents)` to obtain
#'   the weight vector.
#' @return A function of `(parents, noise)` of class `"linear_node_function"`
#'   that computes `as.vector(parents %*% weights) + noise`. The weights are
#'   also stored in the `"weights"` attribute.
sample_linear_function <- function(n_parents, weight_func) {
  w <- weight_func(n_parents)

  structure(
    function(parents, noise) {
      linear_part <- as.vector(parents %*% w)
      linear_part + noise
    },
    n_parents = n_parents,
    weights = w,
    class = "linear_node_function"
  )
}


#' Create a linear node function from an explicit weight vector.
#'
#' @param weights Weight vector multiplied against the parent values.
#' @return A function of `(parent_values, noise)` of class
#'   `"linear_node_function"` that computes
#'   `as.vector(parent_values %*% weights) + noise`. The attributes
#'   `"n_parents"` (set to `length(weights)`) and `"weights"` are attached.
create_linear_function <- function(weights) {
  structure(
    function(parent_values, noise) {
      linear_part <- as.vector(parent_values %*% weights)
      linear_part + noise
    },
    n_parents = length(weights),
    weights = weights,
    class = "linear_node_function"
  )
}

#' Print a linear node function as a readable formula.
#'
#' Writes `return(x_1 * w_1 + ... + x_k * w_k + noise)` using the
#' `"n_parents"` and `"weights"` attributes stored on `x`. With zero parents
#' the output is `return(noise)`. No trailing newline is emitted.
#'
#' @param x An object of class `"linear_node_function"`.
#' @param ... Ignored; present for compatibility with the `print()` generic.
#' @return `x`, invisibly, as is conventional for print methods.
print.linear_node_function <- function(x, ...) {
  n_parents <- attr(x, "n_parents")
  weights <- attr(x, "weights")

  idx <- seq_len(n_parents)
  terms <- sprintf("x_%d * %f", idx, weights[idx])

  # Joining the parent terms together with "noise" avoids a dangling " + "
  # when there are no parent terms at all.
  cat("return(", paste(c(terms, "noise"), collapse = " + "), ")", sep = "")

  invisible(x)
}

#' Write the printed form of an object without index prefix or quotes.
#'
#' The printed representation of `x` is captured line by line. A leading
#' `[1] ` is removed from each line, every double quote is removed, and the
#' remaining lines are written with `cat()`, joined by newlines (no trailing
#' newline).
#'
#' @param x Object whose printed output should be cleaned up.
#' @param ... Not used by this function.
cat_distribution <- function(x, ...) {
  printed_lines <- capture.output(x)
  stripped <- gsub("\"", "", gsub("^\\[1\\] ", "", printed_lines))
  cat(paste(stripped, collapse = "\n"))
}

#' Create a node function of the form y = w_2 * g(x^T w_1) + noise.
#'
#' `g` is a caller-supplied nonlinear function, `w_1` is a vector of weights
#' applied to the parents and `w_2` scales the output of `g`. Each weight is
#' drawn from its weight function when one is supplied; otherwise `w_1` is a
#' vector of ones and `w_2` is 1.
#'
#' @param n_parents Number of parents; passed to `w1_func` and stored on the
#'   result.
#' @param g Nonlinear function applied to `as.vector(parent_data %*% w_1)`.
#' @param w1_func Optional function called as `w1_func(n_parents)` to obtain
#'   `w_1`.
#' @param w2_func Optional function called as `w2_func(1)` to obtain `w_2`.
#' @return A function of `(parent_data, noise)` of class
#'   `"additive_nonlinear_node_function"` with attributes `"g"`,
#'   `"n_parents"`, `"w_1"` and `"w_2"`.
create_additive_nonlinear_function <- function(n_parents, g, w1_func = NULL,
                                               w2_func = NULL) {
  w_1 <- if (is.null(w1_func)) rep(1, n_parents) else w1_func(n_parents)
  w_2 <- if (is.null(w2_func)) 1 else w2_func(1)

  structure(
    function(parent_data, noise) {
      projected <- as.vector(parent_data %*% w_1)
      w_2 * g(projected) + noise
    },
    g = g,
    n_parents = n_parents,
    w_1 = w_1,
    w_2 = w_2,
    class = "additive_nonlinear_node_function"
  )
}

#' Print an additive nonlinear node function as a readable formula.
#'
#' Writes `return(w_2 * g(x_1 * w_1[1] + ... + x_k * w_1[k]) + noise)` using
#' the `"g"`, `"n_parents"`, `"w_1"` and `"w_2"` attributes stored on `x`.
#' Weights are shown with two decimals. If the `"w_2"` attribute is absent,
#' the scaling factor is omitted. No trailing newline is emitted.
#'
#' @param x An object of class `"additive_nonlinear_node_function"`.
#' @param ... Ignored; present for compatibility with the `print()` generic.
#' @return `x`, invisibly, as is conventional for print methods.
print.additive_nonlinear_node_function <- function(x, ...) {
  n_parents <- attr(x, "n_parents")
  w_1 <- attr(x, "w_1")
  w_2 <- attr(x, "w_2")
  g <- attr(x, "g")

  if (is.primitive(g)) {
    # A primitive deparses as `.Primitive("name")`; keep only the name.
    g_label <- gsub("^\\.Primitive\\(\"([a-zA-Z0-9_]+)\"\\)$", "\\1",
                    deparse(g))
  } else {
    # A closure deparses to several lines; join them with single spaces.
    g_label <- paste(deparse(g), collapse = " ")
  }

  idx <- seq_len(n_parents)
  terms <- sprintf("x_%d * %.2f", idx, w_1[idx])

  # sprintf() on a NULL w_2 yields character(0), which cat() simply skips.
  cat(
    "return(", sprintf("%.2f * ", w_2), g_label,
    "(", paste(terms, collapse = " + "), ") + noise)",
    sep = ""
  )

  invisible(x)
}


#' Create a post-nonlinear node function y = g(sum_i f_i(x_i) + noise).
#'
#' One function per parent is drawn (with replacement) from
#' `list_of_f_functions` when this factory is called; the returned function
#' reuses that draw on every call.
#'
#' @param n_parents Number of parents (may be 0).
#' @param g Function applied to the summed parent contributions plus noise.
#' @param list_of_f_functions List of functions to sample the per-parent
#'   functions from.
#' @return A function of `(parents, noise)`. If `parents` is two-dimensional
#'   (a matrix or data frame), column `i` is passed to the `i`-th sampled
#'   function and contributions are summed per row, giving one value per row.
#'   Otherwise `parents` is treated as a single observation and element `i`
#'   is passed to the `i`-th sampled function.
sample_post_nonlinear_function <- function(n_parents, g, list_of_f_functions) {
  # Evaluate g now so the closure does not depend on lazy evaluation later.
  force(g)

  # Sample non-linear functions from provided list
  f <- sample(list_of_f_functions, n_parents, replace = TRUE)
  parent_idx <- seq_len(n_parents)

  function(parents, noise) {
    if (length(dim(parents)) == 2L) {
      # Rows are observations: accumulate each parent's contribution per row.
      f_val <- numeric(nrow(parents))
      for (i in parent_idx) {
        f_val <- f_val + f[[i]](parents[, i])
      }
    } else {
      # Single observation: one value per parent, summed to a scalar.
      f_val <- sum(vapply(parent_idx, function(i) f[[i]](parents[i]),
                          numeric(1)))
    }

    # Apply g to the summed contributions plus noise
    g(f_val + noise)
  }
}
