
# Start from a clean workspace: remove every object that ls() lists.
rm(list = ls())
library("ggplot2")
# Load helper.R at this point, i.e. after the rm() call above (loading it
# earlier would have its definitions removed again).
# NOTE(review): explanability.union(), explanability.intersection() and
# sigmoid() are called below but are not defined in this file; presumably
# helper.R provides them -- confirm.

# Function to generate random factors W
# Input:
  # K: number of factors (default is 5)
  # n: sample size (default is 1)
  # distribution: type of distribution for W's elements ("Uniform", "Gaussian", "Bernoulli"), or a custom function. Default is "Gaussian".
  #   "Uniform":   independent draws from Uniform(-1, 1)
  #   "Gaussian":  independent draws from N(0, 1)
  #   "Bernoulli": independent draws from Bernoulli(0.5)
  #   custom function: called once per sample as distribution(K); each call must
  #     return K numeric values, which become one row of the output.
  #     (The earlier version never called the custom function, so this calling
  #     convention is the one defined here.)
# Output:
  # An n x K matrix of factors from the specified distribution
# Errors:
  # Stops if distribution is neither a function nor one of the three names above.
W.function = function(K = 5, n = 1, distribution = "Gaussian"){
  # The custom generator has to be handled before any string comparison:
  # comparing a function with a string via == is an error in R.
  if(is.function(distribution)){
    # vapply() returns one column per draw (K x n), so fill the result by row.
    draws = vapply(seq_len(n), function(i){as.numeric(distribution(K))}, numeric(K))
    return(matrix(draws, nrow = n, ncol = K, byrow = TRUE))
  }
  values = switch(distribution,
                  "Uniform" = runif(n * K, -1, 1),
                  "Gaussian" = rnorm(n * K, 0, 1),
                  "Bernoulli" = rbinom(n * K, 1, 0.5),
                  stop("Unknown distribution: ", distribution,
                       " (expected \"Uniform\", \"Gaussian\", \"Bernoulli\", or a function)",
                       call. = FALSE))
  matrix(values, nrow = n, ncol = K)
}

# Outcome setting selected for this run; Y.function and the code below branch on it.
# Values handled by Y.function:
#   "Linear"
#   "First order interaction"
#   "Blockwise first order interaction"
#   "Second order interaction"
#   "Two layer neural network"
#   "Two layer neural network uninformative second layer"
setting = "Two layer neural network uninformative second layer"
# Function to generate Y (outcome) based on factor values W
# Input
  # W: a numeric vector of factor values (a single draw of the factors)
# Variables read from the enclosing environment (they are NOT arguments):
  # setting: name of the outcome form, one of "Linear", "First order interaction",
  #          "Blockwise first order interaction", "Second order interaction",
  #          "Two layer neural network",
  #          "Two layer neural network uninformative second layer"
  # K1, K2: number of factors entering the first / second layer
  #         (only used by the two neural-network settings)
  # sigmoid: activation function used by the neural-network settings
  #          (not defined in this file)
# Output:
  # The scalar outcome Y for the given W
# Errors:
  # Stops if setting is not one of the names listed above.
#
# Alternative outcome forms tried earlier, kept for reference (not selectable):
  # (sum(W)^2 - sum(W^2)) / length(W)   first-order interaction
  # W[1] * W[2]                         first-order interaction
  # W[1] + W[1] * W[2]                  first-order interaction
  # (W[1] * W[2] < 0)                   indicator for interaction sign
  # exp(W[1] * W[2])                    exponential interaction
  # sin(W[1] * W[2])                    sine interaction
  # 1 / (1 + exp(W[1] * W[2]))          sigmoid interaction
Y.function = function(W){
  # Outcome of the two-layer network shared by both neural-network settings.
  # e2.scale multiplies the second-layer factors E2 before they are added to
  # the first-layer neurons; it is the only difference between the two settings.
  two.layer.outcome = function(e2.scale){
    # Coefficients for the first layer.
    # The hard-coded indices need K1 >= 3 and K2 >= 2.
    beta1 = matrix(0, nrow = K1, ncol = K2)
    beta1[c(1, 2), 1] = 1; beta1[c(2, 3), 2] = -1
    alpha1 = rep(0, K2); alpha1[c(1, 2)] = 1
    # Coefficients for the second layer
    beta2 = matrix(0, nrow = K2, ncol = 1)
    beta2[1] = 1; beta2[2] = 1
    alpha2 = -10

    # The first K1 entries of W feed the first layer, the rest enter the second.
    W1 = W[seq(1, K1)]; E2 = W[-seq(1, K1)]
    # First layer uses a squared activation.
    # Alternative: neuron1 = sigmoid((t(beta1) %*% W1) + alpha1)
    neuron1 = ((t(beta1) %*% W1) + alpha1)^2
    # Second layer.
    # Alternative: sum((beta2 * (E2 * e2.scale + neuron1))^2) + alpha2
    sum(sigmoid(beta2 * (E2 * e2.scale + neuron1))) + alpha2
  }

  # switch() evaluates only the branch that matches setting.
  switch(setting,
         # Linear effect
         "Linear" = sum(W),
         # All pairwise products of the first three factors
         "First order interaction" = W[1] * W[2] + W[1] * W[3] + W[2] * W[3],
         # Linear combination of sigmoid interactions
         "Blockwise first order interaction" =
           1 / (1 + exp((W[1] + W[2]) * 10)) + 1 / (1 + exp((W[2] + W[3]) * 10)),
         # Product of the first three factors
         "Second order interaction" = W[1] * W[2] * W[3],
         "Two layer neural network" = two.layer.outcome(3),
         # Same network, but the variance contributed by E2 is decreased.
         "Two layer neural network uninformative second layer" = two.layer.outcome(0.5),
         stop("Unknown setting: ", setting, call. = FALSE))
}


# K is the number of factors in W: 5 for either two-layer neural network
# setting, 3 for every other setting.
if(setting == "Two layer neural network" ||
   setting == "Two layer neural network uninformative second layer"){
  K = 5
  # Split of the factors between the layers: K1 factors go to the first layer,
  # the remaining K2 = K - K1 factors go to the second layer.
  K1 = 3
  K2 = K - K1
}else{
  K = 3
}

# Number of Monte Carlo samples passed to every estimate below
# (the earlier note lists 10000 as the default; 1000 is used here).
M = 1000
set.seed(318)

# Total explainability: explanability.union() over all K factors.
total = explanability.union(
  W.function = W.function,
  Y.function = Y.function,
  factor.index = seq(1, K),
  M = M
)

# Explainability of each single factor, in factor order.
singletons = lapply(seq(1, K), function(k){
  explanability.union(
    W.function = W.function,
    Y.function = Y.function,
    factor.index = k,
    M = M
  )
})

# Explainability of each first-order interaction (every pair of factors).
interactions.1st = lapply(combn(seq(1, K), 2, simplify = FALSE), function(pair){
  explanability.intersection(
    W.function = W.function,
    Y.function = Y.function,
    factor.index = pair,
    M = M
  )
})

# Collect the results; everything except the total is reported relative to it.
result = list(total = total)
result[["singleton"]] = unlist(singletons) / total
result[["interaction.1st"]] = unlist(interactions.1st) / total

# The extra summary depends on the kind of setting. Both two-layer neural
# network settings (K = 5, with K1 and K2 defined) report per-layer results;
# all other settings (K = 3) report the second-order interaction.
# Testing only for "Two layer neural network" would send the
# "... uninformative second layer" variant down the K = 3 branch.
if(!(setting %in% c("Two layer neural network",
                    "Two layer neural network uninformative second layer"))){
  # K = 3
  # Explainabilities of second-order interactions (every triple of factors).
  interactions.2nd = lapply(combn(seq(1, K), 3, simplify = FALSE), function(x){
    explanability.intersection(W.function = W.function,
                               Y.function = Y.function,
                               factor.index = x,
                               M = M)
  })
  result[["interaction.2nd"]] = unlist(interactions.2nd) / total
}else{
  # Explainabilities of each layer: factors 1..K1 form the first layer,
  # factors K1+1..K the second layer.
  layers = lapply(list(layer1 = seq(1, K1), layer2 = seq(K1 + 1, K)), function(x){
    explanability.union(W.function = W.function,
                        Y.function = Y.function,
                        factor.index = x,
                        M = M)
  })
  result[["layer"]] = unlist(layers) / total
}

