# Install the required packages only when they are missing, then attach them.
# Calling install.packages() unconditionally re-downloads the packages on every
# run and makes the script fail without network access.
for (pkg in c("sodavis", "RcppCNPy")) {
  if (!requireNamespace(pkg, quietly = TRUE)) {
    install.packages(pkg)
  }
}
library(sodavis)
library(RcppCNPy)

# Canonicalize term labels so that equivalent terms compare equal.
#
# Each label is split on "*", its factors are trimmed and sorted
# (locale-independent radix sort), and the factors are re-joined with "*".
# So "5*4", "4 * 5" and "4*5" all become the same label. Duplicates are
# removed. Non-character input (e.g. a list of length-1 strings or NULL) is
# coerced with as.character() first.
canonicalize_terms <- function(terms) {
  labels <- as.character(terms)
  factors <- strsplit(labels, "*", fixed = TRUE)
  canonical <- vapply(
    factors,
    function(f) paste(sort(trimws(f), method = "radix"), collapse = "*"),
    character(1)
  )
  unique(canonical)
}

# Compute the true-positive and false-positive rates of a selected term set.
#
# Term labels look like "3" (a single predictor) or "4*5" (a pairwise
# interaction). The order of factors inside an interaction label does not
# matter, and duplicate labels are counted once.
#
# Args:
#   truep: labels of the true terms.
#   predp: labels of the selected terms. NULL or empty means nothing was
#          selected.
#   d:     number of predictors. The candidate pool is taken to be the d main
#          effects plus the d * (d - 1) / 2 pairwise interactions.
#
# Returns:
#   Named numeric vector c(TPR = ..., FPR = ...), where
#     TPR = |true intersect selected| / |true|
#     FPR = |selected minus true| / (number of candidate terms not in true).
#   TPR is NaN when truep is empty (0 / 0).
detection <- function(truep, predp, d) {
  true_terms <- canonicalize_terms(truep)
  pred_terms <- canonicalize_terms(predp)

  n_true <- length(true_terms)
  # d main effects + choose(d, 2) interactions == d * (d + 1) / 2.
  n_candidates <- d * (d + 1) / 2

  tpr <- length(intersect(true_terms, pred_terms)) / n_true
  fpr <- length(setdiff(pred_terms, true_terms)) / (n_candidates - n_true)

  c(TPR = tpr, FPR = fpr)
}

# Fit s_soda on every simulated replicate and summarise selection accuracy.
#
# Args:
#   x:    3-D array indexed as x[replicate, observation, predictor], so
#         dim(x) = c(Rep, n, p).
#   y:    2-D object indexed as y[replicate, observation] with Rep rows and
#         n columns. Row i is the response fitted together with x[i, , ].
#   h:    number of slices, passed to s_soda as `H`.
#   tset: labels of the true terms, compared against s_soda's `best_Term`
#         by detection().
#
# Returns:
#   list(TM, TS, FM, FS): the mean (TM, FM) and standard deviation (TS, FS)
#   of the per-replicate TPR and FPR. Each replicate's c(TPR, FPR) is also
#   printed as it is computed.
evaluate <- function(x, y, h, tset) {
  dims <- dim(x)
  stopifnot(
    length(dims) == 3L,
    length(dim(y)) == 2L,
    nrow(y) == dims[1],
    ncol(y) == dims[2]
  )
  n_rep <- dims[1]
  n_pred <- dims[3]

  results <- matrix(
    NA_real_,
    nrow = n_rep,
    ncol = 2L,
    dimnames = list(NULL, c("TPR", "FPR"))
  )

  for (i in seq_len(n_rep)) {
    # Rebuild the n x p design explicitly: x[i, , ] alone would drop to a
    # plain vector when n or p equals 1.
    x_i <- array(x[i, , ], dim = dims[-1], dimnames = dimnames(x)[-1])
    # H is the number of slices; it can be tuned to the sample size or to
    # how continuous the response is.
    model <- s_soda(x_i, y[i, ], H = h)
    # best_Term holds the selected main effects and 2-way interactions.
    rates <- detection(tset, model$best_Term, n_pred)
    results[i, ] <- rates
    print(rates)
  }

  list(
    TM = mean(results[, "TPR"]),
    TS = sd(results[, "TPR"]),
    FM = mean(results[, "FPR"]),
    FS = sd(results[, "FPR"])
  )
}

##############################
#   Simulation scenarios     #
##############################
# Each scenario is a pair of .npy files under ../data/:
#   <stem>_data_X.npy and <stem>_data_y.npy
# `true_set` lists the terms expected in s_soda's best_Term ("j" = predictor
# j alone, "j*k" = interaction of predictors j and k).

# Number of slices (s_soda's H) used for every scenario.
n_slices <- 2

# Load one scenario's design array and responses.
#
# y must be a 2-D Rep x n matrix (evaluate() indexes it as y[i, ]). The flat
# X values are reshaped to Rep x n x p with p = length(X) / length(y). For the
# simulated data used here that is c(50, 300, 150).
#
# NOTE(review): `dim<-` reinterprets the values in R's column-major order,
# while numpy writes C (row-major) order by default. Confirm on the Python
# side that X is saved in a layout for which X[i, , ] really is replicate i
# -- TODO confirm.
load_scenario <- function(stem, data_dir = "../data") {
  X <- npyLoad(file.path(data_dir, paste0(stem, "_data_X.npy")))
  y <- npyLoad(file.path(data_dir, paste0(stem, "_data_y.npy")))
  if (length(dim(y)) != 2L) {
    stop("Expected '", stem, "_data_y.npy' to hold a Rep x n matrix.",
         call. = FALSE)
  }
  if (length(X) %% length(y) != 0L) {
    stop("Size of X in '", stem, "' is not a multiple of the size of y.",
         call. = FALSE)
  }
  dim(X) <- c(dim(y), length(X) %/% length(y))
  list(X = X, y = y)
}

scenarios <- list(
  "Only Main Effects" = list(
    stem = "only_main300", true_set = c("1", "2", "3", "4")
  ),
  "Weak Main Effects" = list(
    stem = "weak_main300", true_set = c("1", "2", "3", "4")
  ),
  "No Overlap Interaction" = list(
    stem = "inter_no_overlap300", true_set = c("1", "2", "3", "4*5")
  ),
  "Mild Overlap Interaction" = list(
    stem = "inter_mild_overlap300", true_set = c("1", "2", "3", "3*4")
  ),
  "Strong Overlap Interaction" = list(
    stem = "inter_strong_overlap300", true_set = c("1", "2", "3", "2*3")
  ),
  "Only Interaction" = list(
    stem = "only_inter300", true_set = c("1*2", "3*4")
  )
)

# Summaries for every scenario, keyed by scenario label. X_arr_3d, y_arr,
# true_set and results are still assigned at top level, so after the loop
# they hold the last scenario's values.
all_results <- setNames(vector("list", length(scenarios)), names(scenarios))

for (label in names(scenarios)) {
  spec <- scenarios[[label]]
  loaded <- load_scenario(spec$stem)
  X_arr_3d <- loaded$X
  y_arr <- loaded$y
  true_set <- spec$true_set
  cat("\n# ", label, "\n", sep = "")
  results <- evaluate(X_arr_3d, y_arr, n_slices, true_set)
  # Print explicitly: top-level autoprint does not happen under source().
  print(results)
  all_results[[label]] <- results
}
