# Helper functions for causal ANOVA
# This script contains functions for estimating the explainabilities of unions and intersections of causally independent factors. For dependent factors, the user should first construct the underlying causally independent random variables that generate these dependent factors, and then apply the explainability functions to those extracted independent factors.

# Logistic (sigmoid) function, applied elementwise.
# Input:
# x: numeric vector or array.
# Output:
# 1 / (1 + exp(-x)) for each element of x, with the same shape as x.
sigmoid = function(x){
  # stats::plogis evaluates the logistic CDF stably. The naive form
  # exp(x) / (1 + exp(x)) overflows to Inf / Inf = NaN once x exceeds ~710.
  stats::plogis(x)
}


# Function to estimate the explainability of the union of a subset of factors via Monte Carlo, using the pick-and-freeze algorithm.
# Each of the M draws is an independent call to explanability.union.helper, and the estimate is their average.
# Input:
# W.function: function to generate factors W; it is called as W.function(n = 1).
# Y.function: function to calculate outcome Y based on W.
# factor.index: indices of factors included in the subset for union explainability.
# M: number of Monte Carlo draws for approximation; must be a single number >= 1. Default is 100.
# Output:
# Estimated union explainability of the subset (the mean of the M draws).
explanability.union = function(W.function, Y.function, factor.index, M = 100){
  # Reject M < 1 up front: replicate(0, ...) returns an empty list, and mean()
  # of that is NA with only a warning, which would silently propagate.
  if (!is.numeric(M) || length(M) != 1L || is.na(M) || M < 1) {
    stop("M must be a single number >= 1.", call. = FALSE)
  }
  draws <- replicate(
    M,
    explanability.union.helper(
      W.function = W.function,
      Y.function = Y.function,
      factor.index = factor.index
    )
  )
  mean(draws)
}

# Helper function to compute one pick-and-freeze realization of the union explainability.
# Two independent draws W1 and W2 are generated. Every factor NOT in factor.index is then copied
# ("frozen") from W1 into W2, so W1 and W2 differ only in the factors of factor.index.
# Input: refer to the input of the function explanability.union. factor.index must be numeric
# (positions into W); an empty factor.index freezes every factor, so the draw is exactly 0.
# Output:
# (Y.function(W1) - Y.function(W2))^2 / 2 for this single Monte Carlo sample.
explanability.union.helper = function(W.function, Y.function, factor.index){
  if (!is.numeric(factor.index)) {
    stop("factor.index must be a numeric vector of factor positions.", call. = FALSE)
  }
  W1 <- W.function(n = 1)
  W2 <- W.function(n = 1)
  # Compute the frozen positions explicitly. Negative indexing W2[-factor.index]
  # is wrong when factor.index is empty: x[-integer(0)] selects NOTHING, so no
  # factor would be frozen and the draw would estimate the total variance.
  frozen <- setdiff(seq_along(W1), factor.index)
  W2[frozen] <- W1[frozen]
  (Y.function(W1) - Y.function(W2))^2 / 2
}


# Function to calculate the explainability of the intersection of a subset of factors. We use the inclusion-exclusion principle to compute intersection explainabilities from the union explainabilities computed by the function explanability.union:
#   intersection(S) = sum over non-empty T subset of S of (-1)^(|T| - 1) * union(T).
# Each union(T) is estimated with its own M Monte Carlo draws.
# Input: refer to the input of the function explanability.union.
# factor.index: indices of factors included in the subset for intersection explainability; must be non-empty. Duplicate indices are ignored.
# Output:
# Estimated intersection explainability of the subset.
explanability.intersection = function(W.function, Y.function, factor.index, M = 100){
  # Repeated indices describe the same set of factors.
  factor.index <- unique(factor.index)
  n.factors <- length(factor.index)
  if (n.factors == 0L) {
    stop("factor.index must contain at least one factor.", call. = FALSE)
  }
  # Enumerate all non-empty subsets by choosing POSITIONS and mapping them back.
  # combn(x, m) treats a length-one numeric x as seq_len(x), so calling
  # combn(factor.index, 1) with factor.index = 3 would wrongly enumerate
  # {1}, {2}, {3} instead of the single subset {3}.
  subsets <- unlist(
    lapply(seq_len(n.factors), function(size) {
      combn(n.factors, size, FUN = function(pos) factor.index[pos], simplify = FALSE)
    }),
    recursive = FALSE
  )
  union.values <- vapply(subsets, function(s) {
    explanability.union(W.function = W.function, Y.function = Y.function, factor.index = s, M = M)
  }, numeric(1))
  # Inclusion-exclusion signs: + for odd-sized subsets, - for even-sized ones.
  signs <- (-1)^(lengths(subsets) - 1L)
  sum(signs * union.values)
}
