library(expm)


##main function###############

# Distributed Sobel's test and distributed MaxP test.
#
# Input: the total sample Y, M, A, X, together with nk_set, which records the
# indices of the observations held by each machine. For example, for balanced
# data on K machines: nk_set <- split(1:N, rep(1:K, times = N / K))
#
# For every machine two no-intercept linear models are fitted on that
# machine's observations,
#   Y ~ A + M + X   and   M ~ A + X,
# and the t statistic of M (first model) and of A (second model) is kept.
# The per-machine t statistics are combined with weights sqrt(n_k / N), where
# n_k is the machine's sample size and N the sum of all n_k. The combined
# statistics are then used for Sobel's test (reject when |T| > 1.96) and for
# the MaxP test (reject when the larger of the two two-sided normal p-values
# is below 0.05).
#
# Arguments:
#   Y, M, A  vectors (or one-column matrices) indexed by the entries of nk_set
#   X        confounders, one row per observation; a plain vector is treated
#            as a single confounder column
#   nk_set   list of index vectors, one element per machine
#
# Returns a character vector of length 2: the decision of the distributed
# Sobel's test followed by the decision of the distributed MaxP test.
f_Dist <- function(Y, M, A, X, nk_set) {
  # The number of machines is taken from nk_set itself, not from a global K.
  n_machines <- length(nk_set)
  stopifnot(n_machines > 0)

  # Coerce once so that row subsetting below works for vectors as well.
  X <- as.matrix(X)
  n_set <- lengths(nk_set, use.names = FALSE)

  T_beta_avei <- numeric(n_machines)
  T_gamma_avei <- numeric(n_machines)
  for (ik in seq_len(n_machines)) {
    nk_set_ik <- nk_set[[ik]]
    Yi <- Y[nk_set_ik]
    Mi <- M[nk_set_ik]
    Ai <- A[nk_set_ik]
    # drop = FALSE keeps Xi a matrix even with one column or one row
    Xi <- X[nk_set_ik, , drop = FALSE]

    fit_Yi <- lm(Yi ~ -1 + Ai + Mi + Xi)
    fit_Mi <- lm(Mi ~ -1 + Ai + Xi)

    # "t value" is estimate / standard error as reported by summary.lm()
    T_beta_avei[ik] <- summary(fit_Yi)$coefficients["Mi", "t value"]
    T_gamma_avei[ik] <- summary(fit_Mi)$coefficients["Ai", "t value"]
  }

  # Combine the per-machine statistics with weights sqrt(n_k / N)
  machine_weights <- sqrt(n_set / sum(n_set))
  T_beta_ave <- sum(T_beta_avei * machine_weights)
  T_gamma_ave <- sum(T_gamma_avei * machine_weights)

  T_sobel_Dis <- T_beta_ave * T_gamma_ave / sqrt(T_beta_ave^2 + T_gamma_ave^2)
  T_MaxP_Dis <- max(
    2 - 2 * pnorm(abs(T_beta_ave)),
    2 - 2 * pnorm(abs(T_gamma_ave))
  )

  if (abs(T_sobel_Dis) > 1.96) {
    Sobels_result <- "The distributed Sobel's test rejects H0"
  } else {
    Sobels_result <- "The distributed Sobel's test does not reject H0"
  }
  if (T_MaxP_Dis < 0.05) {
    MaxP_result <- "The distributed MaxP test rejects H0"
  } else {
    MaxP_result <- "The distributed MaxP test does not reject H0"
  }
  c(Sobels_result, MaxP_result)
}


## settings ###################
p <- 3 # dimension of the confounder
K <- 4 # the number of machines

# Confounder covariance: entry (i, j) equals 0.5^|i - j|
sigma_X <- 0.5^abs(outer(seq_len(p), seq_len(p), FUN = "-"))
# Matrix square root, used below to give the simulated X this covariance
sigma_X_sqrtm <- sqrtm(sigma_X)

b1 <- c(1, -0.5, 1) # beta_X in equation (7); multiplies X in the model for M
b2 <- c(0.5, 1, -1) # gamma_X in equation (8); multiplies X in the model for Y

# Candidate (beta, gamma) settings
param_set <- list(
  # at least one of beta and gamma is zero
  list(beta = 0, gamma = 0),
  list(beta = 0.2, gamma = 0),
  list(beta = 0, gamma = 0.2),
  # both beta and gamma are non-zero
  list(beta = 0.1, gamma = 0.1),
  list(beta = 0.05, gamma = 0.2),
  list(beta = 0.2, gamma = 0.05)
)

# The second setting is the one used by the runs below
beta_true <- param_set[[2]][["beta"]]
gamma_true <- param_set[[2]][["gamma"]]


##4.1 Balanced data###########
N <- 2^11 # total sample size
# Deal the N indices out to the K machines in turn (1, 2, ..., K, 1, 2, ...),
# so every machine receives N / K observations
nk_set <- split(seq_len(N), rep(seq_len(K), times = N / K))

# Confounders: N x p standard normal draws shifted by 1, then multiplied by
# the square root of sigma_X
X <- matrix(rnorm(N * p) + 1, nrow = N, ncol = p) %*% sigma_X_sqrtm
# Binary exposure with success probability 0.5
A <- rbinom(N, size = 1, prob = 0.5)
# Standard normal errors for the mediator and outcome models
epsilon_m <- rnorm(N)
epsilon_y <- rnorm(N)
# Mediator model, then outcome model (the coefficient of A on Y is fixed at 1)
M <- gamma_true * A + X %*% b1 + epsilon_m
Y <- 1 * A + beta_true * M + X %*% b2 + epsilon_y

T_Dis <- f_Dist(Y, M, A, X, nk_set)


##4.2 Unbalanced data###########

# One sample size per machine, each drawn as floor(Uniform(50, 150)).
# A single vectorised draw uses the same random numbers as K successive
# runif(1, 50, 150) calls and also works when K is 1, where a loop over
# 2:K would count downwards (2, 1) and leave n_set with two entries.
n_set <- floor(runif(K, min = 50, max = 150))
N <- sum(n_set) # total sample size
# Machine k receives a contiguous run of n_set[k] indices
nk_set <- split(seq_len(N), rep(seq_len(K), times = n_set))

# Confounders: N x p standard normal draws shifted by 1, then multiplied by
# the square root of sigma_X
X <- matrix(rnorm(N * p) + 1, nrow = N) %*% sigma_X_sqrtm
# Binary exposure with success probability 0.5
A <- rbinom(N, 1, 0.5)
# Standard normal errors for the mediator and outcome models
epsilon_m <- rnorm(N, 0, 1)
epsilon_y <- rnorm(N, 0, 1)
# Mediator model, then outcome model (the coefficient of A on Y is fixed at 1)
M <- gamma_true * A + X %*% b1 + epsilon_m
Y <- 1 * A + beta_true * M + X %*% b2 + epsilon_y

T_Dis <- f_Dist(Y, M, A, X, nk_set)
